Unverified Commit dd31208a authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[ci] [python-package] enable ruff-format on all Python code (#6336)

parent 2a085655
...@@ -23,14 +23,33 @@ except ImportError: ...@@ -23,14 +23,33 @@ except ImportError:
pass pass
_version_path = Path(__file__).absolute().parent / 'VERSION.txt' _version_path = Path(__file__).absolute().parent / "VERSION.txt"
if _version_path.is_file(): if _version_path.is_file():
__version__ = _version_path.read_text(encoding='utf-8').strip() __version__ = _version_path.read_text(encoding="utf-8").strip()
__all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence', __all__ = [
'register_logger', "Dataset",
'train', 'cv', "Booster",
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', "CVBooster",
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker', "Sequence",
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'EarlyStopException', "register_logger",
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph'] "train",
"cv",
"LGBMModel",
"LGBMRegressor",
"LGBMClassifier",
"LGBMRanker",
"DaskLGBMRegressor",
"DaskLGBMClassifier",
"DaskLGBMRanker",
"log_evaluation",
"record_evaluation",
"reset_parameter",
"early_stopping",
"EarlyStopException",
"plot_importance",
"plot_split_value_histogram",
"plot_metric",
"plot_tree",
"create_tree_digraph",
]
...@@ -48,31 +48,31 @@ if TYPE_CHECKING: ...@@ -48,31 +48,31 @@ if TYPE_CHECKING:
__all__ = [ __all__ = [
'Booster', "Booster",
'Dataset', "Dataset",
'LGBMDeprecationWarning', "LGBMDeprecationWarning",
'LightGBMError', "LightGBMError",
'register_logger', "register_logger",
'Sequence', "Sequence",
] ]
_BoosterHandle = ctypes.c_void_p _BoosterHandle = ctypes.c_void_p
_DatasetHandle = ctypes.c_void_p _DatasetHandle = ctypes.c_void_p
_ctypes_int_ptr = Union[ _ctypes_int_ptr = Union[
"ctypes._Pointer[ctypes.c_int32]", "ctypes._Pointer[ctypes.c_int32]",
"ctypes._Pointer[ctypes.c_int64]" "ctypes._Pointer[ctypes.c_int64]",
] ]
_ctypes_int_array = Union[ _ctypes_int_array = Union[
"ctypes.Array[ctypes._Pointer[ctypes.c_int32]]", "ctypes.Array[ctypes._Pointer[ctypes.c_int32]]",
"ctypes.Array[ctypes._Pointer[ctypes.c_int64]]" "ctypes.Array[ctypes._Pointer[ctypes.c_int64]]",
] ]
_ctypes_float_ptr = Union[ _ctypes_float_ptr = Union[
"ctypes._Pointer[ctypes.c_float]", "ctypes._Pointer[ctypes.c_float]",
"ctypes._Pointer[ctypes.c_double]" "ctypes._Pointer[ctypes.c_double]",
] ]
_ctypes_float_array = Union[ _ctypes_float_array = Union[
"ctypes.Array[ctypes._Pointer[ctypes.c_float]]", "ctypes.Array[ctypes._Pointer[ctypes.c_float]]",
"ctypes.Array[ctypes._Pointer[ctypes.c_double]]" "ctypes.Array[ctypes._Pointer[ctypes.c_double]]",
] ]
_LGBM_EvalFunctionResultType = Tuple[str, float, bool] _LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]] _LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
...@@ -90,7 +90,7 @@ _LGBM_GroupType = Union[ ...@@ -90,7 +90,7 @@ _LGBM_GroupType = Union[
] ]
_LGBM_PositionType = Union[ _LGBM_PositionType = Union[
np.ndarray, np.ndarray,
pd_Series pd_Series,
] ]
_LGBM_InitScoreType = Union[ _LGBM_InitScoreType = Union[
List[float], List[float],
...@@ -112,7 +112,7 @@ _LGBM_TrainDataType = Union[ ...@@ -112,7 +112,7 @@ _LGBM_TrainDataType = Union[
"Sequence", "Sequence",
List["Sequence"], List["Sequence"],
List[np.ndarray], List[np.ndarray],
pa_Table pa_Table,
] ]
_LGBM_LabelType = Union[ _LGBM_LabelType = Union[
List[float], List[float],
...@@ -140,6 +140,19 @@ _LGBM_WeightType = Union[ ...@@ -140,6 +140,19 @@ _LGBM_WeightType = Union[
pa_Array, pa_Array,
pa_ChunkedArray, pa_ChunkedArray,
] ]
_LGBM_SetFieldType = Union[
List[List[float]],
List[List[int]],
List[float],
List[int],
np.ndarray,
pd_Series,
pd_DataFrame,
pa_Table,
pa_Array,
pa_ChunkedArray,
]
ZERO_THRESHOLD = 1e-35 ZERO_THRESHOLD = 1e-35
...@@ -149,18 +162,20 @@ def _is_zero(x: float) -> bool: ...@@ -149,18 +162,20 @@ def _is_zero(x: float) -> bool:
def _get_sample_count(total_nrow: int, params: str) -> int: def _get_sample_count(total_nrow: int, params: str) -> int:
sample_cnt = ctypes.c_int(0) sample_cnt = ctypes.c_int(0)
_safe_call(_LIB.LGBM_GetSampleCount( _safe_call(
ctypes.c_int32(total_nrow), _LIB.LGBM_GetSampleCount(
_c_str(params), ctypes.c_int32(total_nrow),
ctypes.byref(sample_cnt), _c_str(params),
)) ctypes.byref(sample_cnt),
)
)
return sample_cnt.value return sample_cnt.value
class _MissingType(Enum): class _MissingType(Enum):
NONE = 'None' NONE = "None"
NAN = 'NaN' NAN = "NaN"
ZERO = 'Zero' ZERO = "Zero"
class _DummyLogger: class _DummyLogger:
...@@ -181,7 +196,9 @@ def _has_method(logger: Any, method_name: str) -> bool: ...@@ -181,7 +196,9 @@ def _has_method(logger: Any, method_name: str) -> bool:
def register_logger( def register_logger(
logger: Any, info_method_name: str = "info", warning_method_name: str = "warning" logger: Any,
info_method_name: str = "info",
warning_method_name: str = "warning",
) -> None: ) -> None:
"""Register custom logger. """Register custom logger.
...@@ -195,9 +212,7 @@ def register_logger( ...@@ -195,9 +212,7 @@ def register_logger(
Method used to log warning messages. Method used to log warning messages.
""" """
if not _has_method(logger, info_method_name) or not _has_method(logger, warning_method_name): if not _has_method(logger, info_method_name) or not _has_method(logger, warning_method_name):
raise TypeError( raise TypeError(f"Logger must provide '{info_method_name}' and '{warning_method_name}' method")
f"Logger must provide '{info_method_name}' and '{warning_method_name}' method"
)
global _LOGGER, _INFO_METHOD_NAME, _WARNING_METHOD_NAME global _LOGGER, _INFO_METHOD_NAME, _WARNING_METHOD_NAME
_LOGGER = logger _LOGGER = logger
...@@ -212,8 +227,8 @@ def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], Non ...@@ -212,8 +227,8 @@ def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], Non
@wraps(func) @wraps(func)
def wrapper(msg: str) -> None: def wrapper(msg: str) -> None:
nonlocal msg_normalized nonlocal msg_normalized
if msg.strip() == '': if msg.strip() == "":
msg = ''.join(msg_normalized) msg = "".join(msg_normalized)
msg_normalized = [] msg_normalized = []
return func(msg) return func(msg)
else: else:
...@@ -237,7 +252,7 @@ def _log_native(msg: str) -> None: ...@@ -237,7 +252,7 @@ def _log_native(msg: str) -> None:
def _log_callback(msg: bytes) -> None: def _log_callback(msg: bytes) -> None:
"""Redirect logs from native library into Python.""" """Redirect logs from native library into Python."""
_log_native(str(msg.decode('utf-8'))) _log_native(str(msg.decode("utf-8")))
def _load_lib() -> ctypes.CDLL: def _load_lib() -> ctypes.CDLL:
...@@ -248,14 +263,15 @@ def _load_lib() -> ctypes.CDLL: ...@@ -248,14 +263,15 @@ def _load_lib() -> ctypes.CDLL:
callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p) callback = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
lib.callback = callback(_log_callback) # type: ignore[attr-defined] lib.callback = callback(_log_callback) # type: ignore[attr-defined]
if lib.LGBM_RegisterLogCallback(lib.callback) != 0: if lib.LGBM_RegisterLogCallback(lib.callback) != 0:
raise LightGBMError(lib.LGBM_GetLastError().decode('utf-8')) raise LightGBMError(lib.LGBM_GetLastError().decode("utf-8"))
return lib return lib
# we don't need lib_lightgbm while building docs # we don't need lib_lightgbm while building docs
_LIB: ctypes.CDLL _LIB: ctypes.CDLL
if environ.get('LIGHTGBM_BUILD_DOC', False): if environ.get("LIGHTGBM_BUILD_DOC", False):
from unittest.mock import Mock # isort: skip from unittest.mock import Mock # isort: skip
_LIB = Mock(ctypes.CDLL) # type: ignore _LIB = Mock(ctypes.CDLL) # type: ignore
else: else:
_LIB = _load_lib() _LIB = _load_lib()
...@@ -273,7 +289,7 @@ def _safe_call(ret: int) -> None: ...@@ -273,7 +289,7 @@ def _safe_call(ret: int) -> None:
The return value from C API calls. The return value from C API calls.
""" """
if ret != 0: if ret != 0:
raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8')) raise LightGBMError(_LIB.LGBM_GetLastError().decode("utf-8"))
def _is_numeric(obj: Any) -> bool: def _is_numeric(obj: Any) -> bool:
...@@ -313,39 +329,28 @@ def _is_1d_list(data: Any) -> bool: ...@@ -313,39 +329,28 @@ def _is_1d_list(data: Any) -> bool:
def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]": def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
return ( return isinstance(data, list) and all(isinstance(x, np.ndarray) for x in data)
isinstance(data, list)
and all(isinstance(x, np.ndarray) for x in data)
)
def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]": def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
return ( return isinstance(data, list) and all(isinstance(x, Sequence) for x in data)
isinstance(data, list)
and all(isinstance(x, Sequence) for x in data)
)
def _is_1d_collection(data: Any) -> bool: def _is_1d_collection(data: Any) -> bool:
"""Check whether data is a 1-D collection.""" """Check whether data is a 1-D collection."""
return ( return _is_numpy_1d_array(data) or _is_numpy_column_array(data) or _is_1d_list(data) or isinstance(data, pd_Series)
_is_numpy_1d_array(data)
or _is_numpy_column_array(data)
or _is_1d_list(data)
or isinstance(data, pd_Series)
)
def _list_to_1d_numpy( def _list_to_1d_numpy(
data: Any, data: Any,
dtype: "np.typing.DTypeLike", dtype: "np.typing.DTypeLike",
name: str name: str,
) -> np.ndarray: ) -> np.ndarray:
"""Convert data to numpy 1-D array.""" """Convert data to numpy 1-D array."""
if _is_numpy_1d_array(data): if _is_numpy_1d_array(data):
return _cast_numpy_array_to_dtype(data, dtype) return _cast_numpy_array_to_dtype(data, dtype)
elif _is_numpy_column_array(data): elif _is_numpy_column_array(data):
_log_warning('Converting column-vector to 1d array') _log_warning("Converting column-vector to 1d array")
array = data.ravel() array = data.ravel()
return _cast_numpy_array_to_dtype(array, dtype) return _cast_numpy_array_to_dtype(array, dtype)
elif _is_1d_list(data): elif _is_1d_list(data):
...@@ -354,8 +359,9 @@ def _list_to_1d_numpy( ...@@ -354,8 +359,9 @@ def _list_to_1d_numpy(
_check_for_bad_pandas_dtypes(data.to_frame().dtypes) _check_for_bad_pandas_dtypes(data.to_frame().dtypes)
return np.array(data, dtype=dtype, copy=False) # SparseArray should be supported as well return np.array(data, dtype=dtype, copy=False) # SparseArray should be supported as well
else: else:
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n" raise TypeError(
"It should be list, numpy 1-D array or pandas Series") f"Wrong type({type(data).__name__}) for {name}.\n" "It should be list, numpy 1-D array or pandas Series"
)
def _is_numpy_2d_array(data: Any) -> bool: def _is_numpy_2d_array(data: Any) -> bool:
...@@ -370,11 +376,7 @@ def _is_2d_list(data: Any) -> bool: ...@@ -370,11 +376,7 @@ def _is_2d_list(data: Any) -> bool:
def _is_2d_collection(data: Any) -> bool: def _is_2d_collection(data: Any) -> bool:
"""Check whether data is a 2-D collection.""" """Check whether data is a 2-D collection."""
return ( return _is_numpy_2d_array(data) or _is_2d_list(data) or isinstance(data, pd_DataFrame)
_is_numpy_2d_array(data)
or _is_2d_list(data)
or isinstance(data, pd_DataFrame)
)
def _is_pyarrow_array(data: Any) -> "TypeGuard[Union[pa_Array, pa_ChunkedArray]]": def _is_pyarrow_array(data: Any) -> "TypeGuard[Union[pa_Array, pa_ChunkedArray]]":
...@@ -438,11 +440,10 @@ def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray: ...@@ -438,11 +440,10 @@ def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
return _ArrowCArray(len(chunks), chunks, schema) return _ArrowCArray(len(chunks), chunks, schema)
def _data_to_2d_numpy( def _data_to_2d_numpy(
data: Any, data: Any,
dtype: "np.typing.DTypeLike", dtype: "np.typing.DTypeLike",
name: str name: str,
) -> np.ndarray: ) -> np.ndarray:
"""Convert data to numpy 2-D array.""" """Convert data to numpy 2-D array."""
if _is_numpy_2d_array(data): if _is_numpy_2d_array(data):
...@@ -452,8 +453,10 @@ def _data_to_2d_numpy( ...@@ -452,8 +453,10 @@ def _data_to_2d_numpy(
if isinstance(data, pd_DataFrame): if isinstance(data, pd_DataFrame):
_check_for_bad_pandas_dtypes(data.dtypes) _check_for_bad_pandas_dtypes(data.dtypes)
return _cast_numpy_array_to_dtype(data.values, dtype) return _cast_numpy_array_to_dtype(data.values, dtype)
raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n" raise TypeError(
"It should be list of lists, numpy 2-D array or pandas DataFrame") f"Wrong type({type(data).__name__}) for {name}.\n"
"It should be list of lists, numpy 2-D array or pandas DataFrame"
)
def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
...@@ -461,7 +464,7 @@ def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndar ...@@ -461,7 +464,7 @@ def _cfloat32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndar
if isinstance(cptr, ctypes.POINTER(ctypes.c_float)): if isinstance(cptr, ctypes.POINTER(ctypes.c_float)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy() return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else: else:
raise RuntimeError('Expected float pointer') raise RuntimeError("Expected float pointer")
def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
...@@ -469,7 +472,7 @@ def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndar ...@@ -469,7 +472,7 @@ def _cfloat64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndar
if isinstance(cptr, ctypes.POINTER(ctypes.c_double)): if isinstance(cptr, ctypes.POINTER(ctypes.c_double)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy() return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else: else:
raise RuntimeError('Expected double pointer') raise RuntimeError("Expected double pointer")
def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
...@@ -477,7 +480,7 @@ def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarra ...@@ -477,7 +480,7 @@ def _cint32_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarra
if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)): if isinstance(cptr, ctypes.POINTER(ctypes.c_int32)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy() return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else: else:
raise RuntimeError('Expected int32 pointer') raise RuntimeError("Expected int32 pointer")
def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray: def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarray:
...@@ -485,12 +488,12 @@ def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarra ...@@ -485,12 +488,12 @@ def _cint64_array_to_numpy(*, cptr: "ctypes._Pointer", length: int) -> np.ndarra
if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)): if isinstance(cptr, ctypes.POINTER(ctypes.c_int64)):
return np.ctypeslib.as_array(cptr, shape=(length,)).copy() return np.ctypeslib.as_array(cptr, shape=(length,)).copy()
else: else:
raise RuntimeError('Expected int64 pointer') raise RuntimeError("Expected int64 pointer")
def _c_str(string: str) -> ctypes.c_char_p: def _c_str(string: str) -> ctypes.c_char_p:
"""Convert a Python string to C string.""" """Convert a Python string to C string."""
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode("utf-8"))
def _c_array(ctype: type, values: List[Any]) -> ctypes.Array: def _c_array(ctype: type, values: List[Any]) -> ctypes.Array:
...@@ -527,8 +530,8 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str: ...@@ -527,8 +530,8 @@ def _param_dict_to_str(data: Optional[Dict[str, Any]]) -> str:
elif isinstance(val, (str, Path, _NUMERIC_TYPES)) or _is_numeric(val): elif isinstance(val, (str, Path, _NUMERIC_TYPES)) or _is_numeric(val):
pairs.append(f"{key}={val}") pairs.append(f"{key}={val}")
elif val is not None: elif val is not None:
raise TypeError(f'Unknown type of parameter:{key}, got:{type(val).__name__}') raise TypeError(f"Unknown type of parameter:{key}, got:{type(val).__name__}")
return ' '.join(pairs) return " ".join(pairs)
class _TempFile: class _TempFile:
...@@ -568,22 +571,27 @@ class _ConfigAliases: ...@@ -568,22 +571,27 @@ class _ConfigAliases:
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer)) ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_DumpParamAliases( _safe_call(
ctypes.c_int64(buffer_len), _LIB.LGBM_DumpParamAliases(
ctypes.byref(tmp_out_len), ctypes.c_int64(buffer_len),
ptr_string_buffer)) ctypes.byref(tmp_out_len),
ptr_string_buffer,
)
)
actual_len = tmp_out_len.value actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer # if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len: if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len) string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer)) ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_DumpParamAliases( _safe_call(
ctypes.c_int64(actual_len), _LIB.LGBM_DumpParamAliases(
ctypes.byref(tmp_out_len), ctypes.c_int64(actual_len),
ptr_string_buffer)) ctypes.byref(tmp_out_len),
ptr_string_buffer,
)
)
return json.loads( return json.loads(
string_buffer.value.decode('utf-8'), string_buffer.value.decode("utf-8"), object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
) )
@classmethod @classmethod
...@@ -692,13 +700,13 @@ _FIELD_TYPE_MAPPER = { ...@@ -692,13 +700,13 @@ _FIELD_TYPE_MAPPER = {
"weight": _C_API_DTYPE_FLOAT32, "weight": _C_API_DTYPE_FLOAT32,
"init_score": _C_API_DTYPE_FLOAT64, "init_score": _C_API_DTYPE_FLOAT64,
"group": _C_API_DTYPE_INT32, "group": _C_API_DTYPE_INT32,
"position": _C_API_DTYPE_INT32 "position": _C_API_DTYPE_INT32,
} }
"""String name to int feature importance type mapper""" """String name to int feature importance type mapper"""
_FEATURE_IMPORTANCE_TYPE_MAPPER = { _FEATURE_IMPORTANCE_TYPE_MAPPER = {
"split": _C_API_FEATURE_IMPORTANCE_SPLIT, "split": _C_API_FEATURE_IMPORTANCE_SPLIT,
"gain": _C_API_FEATURE_IMPORTANCE_GAIN "gain": _C_API_FEATURE_IMPORTANCE_GAIN,
} }
...@@ -706,15 +714,15 @@ def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray: ...@@ -706,15 +714,15 @@ def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray:
"""Fix the memory of multi-dimensional sliced object.""" """Fix the memory of multi-dimensional sliced object."""
if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray): if isinstance(data, np.ndarray) and isinstance(data.base, np.ndarray):
if not data.flags.c_contiguous: if not data.flags.c_contiguous:
_log_warning("Usage of np.ndarray subset (sliced data) is not recommended " _log_warning(
"due to it will double the peak memory cost in LightGBM.") "Usage of np.ndarray subset (sliced data) is not recommended "
"due to it will double the peak memory cost in LightGBM."
)
return np.copy(data) return np.copy(data)
return data return data
def _c_float_array( def _c_float_array(data: np.ndarray) -> Tuple[_ctypes_float_ptr, int, np.ndarray]:
data: np.ndarray
) -> Tuple[_ctypes_float_ptr, int, np.ndarray]:
"""Get pointer of float numpy array / list.""" """Get pointer of float numpy array / list."""
if _is_1d_list(data): if _is_1d_list(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
...@@ -735,9 +743,7 @@ def _c_float_array( ...@@ -735,9 +743,7 @@ def _c_float_array(
return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed
def _c_int_array( def _c_int_array(data: np.ndarray) -> Tuple[_ctypes_int_ptr, int, np.ndarray]:
data: np.ndarray
) -> Tuple[_ctypes_int_ptr, int, np.ndarray]:
"""Get pointer of int numpy array / list.""" """Get pointer of int numpy array / list."""
if _is_1d_list(data): if _is_1d_list(data):
data = np.array(data, copy=False) data = np.array(data, copy=False)
...@@ -759,27 +765,26 @@ def _c_int_array( ...@@ -759,27 +765,26 @@ def _c_int_array(
def _is_allowed_numpy_dtype(dtype: type) -> bool: def _is_allowed_numpy_dtype(dtype: type) -> bool:
float128 = getattr(np, 'float128', type(None)) float128 = getattr(np, "float128", type(None))
return ( return issubclass(dtype, (np.integer, np.floating, np.bool_)) and not issubclass(dtype, (np.timedelta64, float128))
issubclass(dtype, (np.integer, np.floating, np.bool_))
and not issubclass(dtype, (np.timedelta64, float128))
)
def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None: def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
bad_pandas_dtypes = [ bad_pandas_dtypes = [
f'{column_name}: {pandas_dtype}' f"{column_name}: {pandas_dtype}"
for column_name, pandas_dtype in pandas_dtypes_series.items() for column_name, pandas_dtype in pandas_dtypes_series.items()
if not _is_allowed_numpy_dtype(pandas_dtype.type) if not _is_allowed_numpy_dtype(pandas_dtype.type)
] ]
if bad_pandas_dtypes: if bad_pandas_dtypes:
raise ValueError('pandas dtypes must be int, float or bool.\n' raise ValueError(
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}') 'pandas dtypes must be int, float or bool.\n'
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}'
)
def _pandas_to_numpy( def _pandas_to_numpy(
data: pd_DataFrame, data: pd_DataFrame,
target_dtype: "np.typing.DTypeLike" target_dtype: "np.typing.DTypeLike",
) -> np.ndarray: ) -> np.ndarray:
_check_for_bad_pandas_dtypes(data.dtypes) _check_for_bad_pandas_dtypes(data.dtypes)
try: try:
...@@ -798,17 +803,17 @@ def _data_from_pandas( ...@@ -798,17 +803,17 @@ def _data_from_pandas(
data: pd_DataFrame, data: pd_DataFrame,
feature_name: _LGBM_FeatureNameConfiguration, feature_name: _LGBM_FeatureNameConfiguration,
categorical_feature: _LGBM_CategoricalFeatureConfiguration, categorical_feature: _LGBM_CategoricalFeatureConfiguration,
pandas_categorical: Optional[List[List]] pandas_categorical: Optional[List[List]],
) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]: ) -> Tuple[np.ndarray, List[str], Union[List[str], List[int]], List[List]]:
if len(data.shape) != 2 or data.shape[0] < 1: if len(data.shape) != 2 or data.shape[0] < 1:
raise ValueError('Input data must be 2 dimensional and non empty.') raise ValueError("Input data must be 2 dimensional and non empty.")
# take shallow copy in case we modify categorical columns # take shallow copy in case we modify categorical columns
# whole column modifications don't change the original df # whole column modifications don't change the original df
data = data.copy(deep=False) data = data.copy(deep=False)
# determine feature names # determine feature names
if feature_name == 'auto': if feature_name == "auto":
feature_name = [str(col) for col in data.columns] feature_name = [str(col) for col in data.columns]
# determine categorical features # determine categorical features
...@@ -818,7 +823,7 @@ def _data_from_pandas( ...@@ -818,7 +823,7 @@ def _data_from_pandas(
pandas_categorical = [list(data[col].cat.categories) for col in cat_cols] pandas_categorical = [list(data[col].cat.categories) for col in cat_cols]
else: else:
if len(cat_cols) != len(pandas_categorical): if len(cat_cols) != len(pandas_categorical):
raise ValueError('train and valid dataset categorical_feature do not match.') raise ValueError("train and valid dataset categorical_feature do not match.")
for col, category in zip(cat_cols, pandas_categorical): for col, category in zip(cat_cols, pandas_categorical):
if list(data[col].cat.categories) != list(category): if list(data[col].cat.categories) != list(category):
data[col] = data[col].cat.set_categories(category) data[col] = data[col].cat.set_categories(category)
...@@ -826,7 +831,7 @@ def _data_from_pandas( ...@@ -826,7 +831,7 @@ def _data_from_pandas(
data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan}) data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes).replace({-1: np.nan})
# use cat cols from DataFrame # use cat cols from DataFrame
if categorical_feature == 'auto': if categorical_feature == "auto":
categorical_feature = cat_cols_not_ordered categorical_feature = cat_cols_not_ordered
df_dtypes = [dtype.type for dtype in data.dtypes] df_dtypes = [dtype.type for dtype in data.dtypes]
...@@ -838,31 +843,31 @@ def _data_from_pandas( ...@@ -838,31 +843,31 @@ def _data_from_pandas(
_pandas_to_numpy(data, target_dtype=target_dtype), _pandas_to_numpy(data, target_dtype=target_dtype),
feature_name, feature_name,
categorical_feature, categorical_feature,
pandas_categorical pandas_categorical,
) )
def _dump_pandas_categorical( def _dump_pandas_categorical(
pandas_categorical: Optional[List[List]], pandas_categorical: Optional[List[List]],
file_name: Optional[Union[str, Path]] = None file_name: Optional[Union[str, Path]] = None,
) -> str: ) -> str:
categorical_json = json.dumps(pandas_categorical, default=_json_default_with_numpy) categorical_json = json.dumps(pandas_categorical, default=_json_default_with_numpy)
pandas_str = f'\npandas_categorical:{categorical_json}\n' pandas_str = f"\npandas_categorical:{categorical_json}\n"
if file_name is not None: if file_name is not None:
with open(file_name, 'a') as f: with open(file_name, "a") as f:
f.write(pandas_str) f.write(pandas_str)
return pandas_str return pandas_str
def _load_pandas_categorical( def _load_pandas_categorical(
file_name: Optional[Union[str, Path]] = None, file_name: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None model_str: Optional[str] = None,
) -> Optional[List[List]]: ) -> Optional[List[List]]:
pandas_key = 'pandas_categorical:' pandas_key = "pandas_categorical:"
offset = -len(pandas_key) offset = -len(pandas_key)
if file_name is not None: if file_name is not None:
max_offset = -getsize(file_name) max_offset = -getsize(file_name)
with open(file_name, 'rb') as f: with open(file_name, "rb") as f:
while True: while True:
if offset < max_offset: if offset < max_offset:
offset = max_offset offset = max_offset
...@@ -871,14 +876,14 @@ def _load_pandas_categorical( ...@@ -871,14 +876,14 @@ def _load_pandas_categorical(
if len(lines) >= 2: if len(lines) >= 2:
break break
offset *= 2 offset *= 2
last_line = lines[-1].decode('utf-8').strip() last_line = lines[-1].decode("utf-8").strip()
if not last_line.startswith(pandas_key): if not last_line.startswith(pandas_key):
last_line = lines[-2].decode('utf-8').strip() last_line = lines[-2].decode("utf-8").strip()
elif model_str is not None: elif model_str is not None:
idx = model_str.rfind('\n', 0, offset) idx = model_str.rfind("\n", 0, offset)
last_line = model_str[idx:].strip() last_line = model_str[idx:].strip()
if last_line.startswith(pandas_key): if last_line.startswith(pandas_key):
return json.loads(last_line[len(pandas_key):]) return json.loads(last_line[len(pandas_key) :])
else: else:
return None return None
...@@ -965,7 +970,7 @@ class _InnerPredictor: ...@@ -965,7 +970,7 @@ class _InnerPredictor:
booster_handle: _BoosterHandle, booster_handle: _BoosterHandle,
pandas_categorical: Optional[List[List]], pandas_categorical: Optional[List[List]],
pred_parameter: Dict[str, Any], pred_parameter: Dict[str, Any],
manage_handle: bool manage_handle: bool,
): ):
"""Initialize the _InnerPredictor. """Initialize the _InnerPredictor.
...@@ -990,7 +995,7 @@ class _InnerPredictor: ...@@ -990,7 +995,7 @@ class _InnerPredictor:
_safe_call( _safe_call(
_LIB.LGBM_BoosterGetNumClasses( _LIB.LGBM_BoosterGetNumClasses(
self._handle, self._handle,
ctypes.byref(out_num_class) ctypes.byref(out_num_class),
) )
) )
self.num_class = out_num_class.value self.num_class = out_num_class.value
...@@ -999,7 +1004,7 @@ class _InnerPredictor: ...@@ -999,7 +1004,7 @@ class _InnerPredictor:
def from_booster( def from_booster(
cls, cls,
booster: "Booster", booster: "Booster",
pred_parameter: Dict[str, Any] pred_parameter: Dict[str, Any],
) -> "_InnerPredictor": ) -> "_InnerPredictor":
"""Initialize an ``_InnerPredictor`` from a ``Booster``. """Initialize an ``_InnerPredictor`` from a ``Booster``.
...@@ -1014,21 +1019,21 @@ class _InnerPredictor: ...@@ -1014,21 +1019,21 @@ class _InnerPredictor:
_safe_call( _safe_call(
_LIB.LGBM_BoosterGetCurrentIteration( _LIB.LGBM_BoosterGetCurrentIteration(
booster._handle, booster._handle,
ctypes.byref(out_cur_iter) ctypes.byref(out_cur_iter),
) )
) )
return cls( return cls(
booster_handle=booster._handle, booster_handle=booster._handle,
pandas_categorical=booster.pandas_categorical, pandas_categorical=booster.pandas_categorical,
pred_parameter=pred_parameter, pred_parameter=pred_parameter,
manage_handle=False manage_handle=False,
) )
@classmethod @classmethod
def from_model_file( def from_model_file(
cls, cls,
model_file: Union[str, Path], model_file: Union[str, Path],
pred_parameter: Dict[str, Any] pred_parameter: Dict[str, Any],
) -> "_InnerPredictor": ) -> "_InnerPredictor":
"""Initialize an ``_InnerPredictor`` from a text file containing a LightGBM model. """Initialize an ``_InnerPredictor`` from a text file containing a LightGBM model.
...@@ -1045,14 +1050,14 @@ class _InnerPredictor: ...@@ -1045,14 +1050,14 @@ class _InnerPredictor:
_LIB.LGBM_BoosterCreateFromModelfile( _LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)), _c_str(str(model_file)),
ctypes.byref(out_num_iterations), ctypes.byref(out_num_iterations),
ctypes.byref(booster_handle) ctypes.byref(booster_handle),
) )
) )
return cls( return cls(
booster_handle=booster_handle, booster_handle=booster_handle,
pandas_categorical=_load_pandas_categorical(file_name=model_file), pandas_categorical=_load_pandas_categorical(file_name=model_file),
pred_parameter=pred_parameter, pred_parameter=pred_parameter,
manage_handle=True manage_handle=True,
) )
def __del__(self) -> None: def __del__(self) -> None:
...@@ -1064,8 +1069,8 @@ class _InnerPredictor: ...@@ -1064,8 +1069,8 @@ class _InnerPredictor:
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
this = self.__dict__.copy() this = self.__dict__.copy()
this.pop('handle', None) this.pop("handle", None)
this.pop('_handle', None) this.pop("_handle", None)
return this return this
def predict( def predict(
...@@ -1077,7 +1082,7 @@ class _InnerPredictor: ...@@ -1077,7 +1082,7 @@ class _InnerPredictor:
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
data_has_header: bool = False, data_has_header: bool = False,
validate_features: bool = False validate_features: bool = False,
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]: ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
"""Predict logic. """Predict logic.
...@@ -1116,7 +1121,7 @@ class _InnerPredictor: ...@@ -1116,7 +1121,7 @@ class _InnerPredictor:
elif isinstance(data, pd_DataFrame) and validate_features: elif isinstance(data, pd_DataFrame) and validate_features:
data_names = [str(x) for x in data.columns] data_names = [str(x) for x in data.columns]
ptr_names = (ctypes.c_char_p * len(data_names))() ptr_names = (ctypes.c_char_p * len(data_names))()
ptr_names[:] = [x.encode('utf-8') for x in data_names] ptr_names[:] = [x.encode("utf-8") for x in data_names]
_safe_call( _safe_call(
_LIB.LGBM_BoosterValidateFeatureNames( _LIB.LGBM_BoosterValidateFeatureNames(
self._handle, self._handle,
...@@ -1130,7 +1135,7 @@ class _InnerPredictor: ...@@ -1130,7 +1135,7 @@ class _InnerPredictor:
data=data, data=data,
feature_name="auto", feature_name="auto",
categorical_feature="auto", categorical_feature="auto",
pandas_categorical=self.pandas_categorical pandas_categorical=self.pandas_categorical,
)[0] )[0]
predict_type = _C_API_PREDICT_NORMAL predict_type = _C_API_PREDICT_NORMAL
...@@ -1143,15 +1148,18 @@ class _InnerPredictor: ...@@ -1143,15 +1148,18 @@ class _InnerPredictor:
if isinstance(data, (str, Path)): if isinstance(data, (str, Path)):
with _TempFile() as f: with _TempFile() as f:
_safe_call(_LIB.LGBM_BoosterPredictForFile( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictForFile(
_c_str(str(data)), self._handle,
ctypes.c_int(data_has_header), _c_str(str(data)),
ctypes.c_int(predict_type), ctypes.c_int(data_has_header),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
_c_str(f.name))) _c_str(self.pred_parameter),
_c_str(f.name),
)
)
preds = np.loadtxt(f.name, dtype=np.float64) preds = np.loadtxt(f.name, dtype=np.float64)
nrow = preds.shape[0] nrow = preds.shape[0]
elif isinstance(data, scipy.sparse.csr_matrix): elif isinstance(data, scipy.sparse.csr_matrix):
...@@ -1159,58 +1167,58 @@ class _InnerPredictor: ...@@ -1159,58 +1167,58 @@ class _InnerPredictor:
csr=data, csr=data,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
preds, nrow = self.__pred_for_csc( preds, nrow = self.__pred_for_csc(
csc=data, csc=data,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
preds, nrow = self.__pred_for_np2d( preds, nrow = self.__pred_for_np2d(
mat=data, mat=data,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
elif _is_pyarrow_table(data): elif _is_pyarrow_table(data):
preds, nrow = self.__pred_for_pyarrow_table( preds, nrow = self.__pred_for_pyarrow_table(
table=data, table=data,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
elif isinstance(data, list): elif isinstance(data, list):
try: try:
data = np.array(data) data = np.array(data)
except BaseException as err: except BaseException as err:
raise ValueError('Cannot convert data list to numpy array.') from err raise ValueError("Cannot convert data list to numpy array.") from err
preds, nrow = self.__pred_for_np2d( preds, nrow = self.__pred_for_np2d(
mat=data, mat=data,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
elif isinstance(data, dt_DataTable): elif isinstance(data, dt_DataTable):
preds, nrow = self.__pred_for_np2d( preds, nrow = self.__pred_for_np2d(
mat=data.to_numpy(), mat=data.to_numpy(),
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
else: else:
try: try:
_log_warning('Converting data to scipy sparse matrix.') _log_warning("Converting data to scipy sparse matrix.")
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
except BaseException as err: except BaseException as err:
raise TypeError(f'Cannot predict data for type {type(data).__name__}') from err raise TypeError(f"Cannot predict data for type {type(data).__name__}") from err
preds, nrow = self.__pred_for_csr( preds, nrow = self.__pred_for_csr(
csr=csr, csr=csr,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
if pred_leaf: if pred_leaf:
preds = preds.astype(np.int32) preds = preds.astype(np.int32)
...@@ -1219,7 +1227,7 @@ class _InnerPredictor: ...@@ -1219,7 +1227,7 @@ class _InnerPredictor:
if preds.size % nrow == 0: if preds.size % nrow == 0:
preds = preds.reshape(nrow, -1) preds = preds.reshape(nrow, -1)
else: else:
raise ValueError(f'Length of predict result ({preds.size}) cannot be divide nrow ({nrow})') raise ValueError(f"Length of predict result ({preds.size}) cannot be divide nrow ({nrow})")
return preds return preds
def __get_num_preds( def __get_num_preds(
...@@ -1227,22 +1235,27 @@ class _InnerPredictor: ...@@ -1227,22 +1235,27 @@ class _InnerPredictor:
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
nrow: int, nrow: int,
predict_type: int predict_type: int,
) -> int: ) -> int:
"""Get size of prediction result.""" """Get size of prediction result."""
if nrow > _MAX_INT32: if nrow > _MAX_INT32:
raise LightGBMError('LightGBM cannot perform prediction for data ' raise LightGBMError(
f'with number of rows greater than MAX_INT32 ({_MAX_INT32}).\n' "LightGBM cannot perform prediction for data "
'You can split your data into chunks ' f"with number of rows greater than MAX_INT32 ({_MAX_INT32}).\n"
'and then concatenate predictions for them') "You can split your data into chunks "
"and then concatenate predictions for them"
)
n_preds = ctypes.c_int64(0) n_preds = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterCalcNumPredict( _safe_call(
self._handle, _LIB.LGBM_BoosterCalcNumPredict(
ctypes.c_int(nrow), self._handle,
ctypes.c_int(predict_type), ctypes.c_int(nrow),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
ctypes.byref(n_preds))) ctypes.c_int(num_iteration),
ctypes.byref(n_preds),
)
)
return n_preds.value return n_preds.value
def __inner_predict_np2d( def __inner_predict_np2d(
...@@ -1251,7 +1264,7 @@ class _InnerPredictor: ...@@ -1251,7 +1264,7 @@ class _InnerPredictor:
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int, predict_type: int,
preds: Optional[np.ndarray] preds: Optional[np.ndarray],
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
if mat.dtype == np.float32 or mat.dtype == np.float64: if mat.dtype == np.float32 or mat.dtype == np.float64:
data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False) data = np.array(mat.reshape(mat.size), dtype=mat.dtype, copy=False)
...@@ -1262,26 +1275,29 @@ class _InnerPredictor: ...@@ -1262,26 +1275,29 @@ class _InnerPredictor:
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
nrow=mat.shape[0], nrow=mat.shape[0],
predict_type=predict_type predict_type=predict_type,
) )
if preds is None: if preds is None:
preds = np.empty(n_preds, dtype=np.float64) preds = np.empty(n_preds, dtype=np.float64)
elif len(preds.shape) != 1 or len(preds) != n_preds: elif len(preds.shape) != 1 or len(preds) != n_preds:
raise ValueError("Wrong length of pre-allocated predict array") raise ValueError("Wrong length of pre-allocated predict array")
out_num_preds = ctypes.c_int64(0) out_num_preds = ctypes.c_int64(0)
_safe_call(_LIB.LGBM_BoosterPredictForMat( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictForMat(
ptr_data, self._handle,
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int32(mat.shape[0]), ctypes.c_int(type_ptr_data),
ctypes.c_int32(mat.shape[1]), ctypes.c_int32(mat.shape[0]),
ctypes.c_int(_C_API_IS_ROW_MAJOR), ctypes.c_int32(mat.shape[1]),
ctypes.c_int(predict_type), ctypes.c_int(_C_API_IS_ROW_MAJOR),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), _c_str(self.pred_parameter),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
)
)
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
raise ValueError("Wrong length for predict results") raise ValueError("Wrong length for predict results")
return preds, mat.shape[0] return preds, mat.shape[0]
...@@ -1291,28 +1307,32 @@ class _InnerPredictor: ...@@ -1291,28 +1307,32 @@ class _InnerPredictor:
mat: np.ndarray, mat: np.ndarray,
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int predict_type: int,
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
"""Predict for a 2-D numpy matrix.""" """Predict for a 2-D numpy matrix."""
if len(mat.shape) != 2: if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray or list must be 2 dimensional') raise ValueError("Input numpy.ndarray or list must be 2 dimensional")
nrow = mat.shape[0] nrow = mat.shape[0]
if nrow > _MAX_INT32: if nrow > _MAX_INT32:
sections = np.arange(start=_MAX_INT32, stop=nrow, step=_MAX_INT32) sections = np.arange(start=_MAX_INT32, stop=nrow, step=_MAX_INT32)
# __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal # __get_num_preds() cannot work with nrow > MAX_INT32, so calculate overall number of predictions piecemeal
n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff([0] + list(sections) + [nrow])] n_preds = [
self.__get_num_preds(start_iteration, num_iteration, i, predict_type)
for i in np.diff([0] + list(sections) + [nrow])
]
n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum() n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
preds = np.empty(sum(n_preds), dtype=np.float64) preds = np.empty(sum(n_preds), dtype=np.float64)
for chunk, (start_idx_pred, end_idx_pred) in zip(np.array_split(mat, sections), for chunk, (start_idx_pred, end_idx_pred) in zip(
zip(n_preds_sections, n_preds_sections[1:])): np.array_split(mat, sections), zip(n_preds_sections, n_preds_sections[1:])
):
# avoid memory consumption by arrays concatenation operations # avoid memory consumption by arrays concatenation operations
self.__inner_predict_np2d( self.__inner_predict_np2d(
mat=chunk, mat=chunk,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type, predict_type=predict_type,
preds=preds[start_idx_pred:end_idx_pred] preds=preds[start_idx_pred:end_idx_pred],
) )
return preds, nrow return preds, nrow
else: else:
...@@ -1321,7 +1341,7 @@ class _InnerPredictor: ...@@ -1321,7 +1341,7 @@ class _InnerPredictor:
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type, predict_type=predict_type,
preds=None preds=None,
) )
def __create_sparse_native( def __create_sparse_native(
...@@ -1333,7 +1353,7 @@ class _InnerPredictor: ...@@ -1333,7 +1353,7 @@ class _InnerPredictor:
out_ptr_data: "ctypes._Pointer", out_ptr_data: "ctypes._Pointer",
indptr_type: int, indptr_type: int,
data_type: int, data_type: int,
is_csr: bool is_csr: bool,
) -> Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]]: ) -> Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]]:
# create numpy array from output arrays # create numpy array from output arrays
data_indices_len = out_shape[0] data_indices_len = out_shape[0]
...@@ -1362,8 +1382,8 @@ class _InnerPredictor: ...@@ -1362,8 +1382,8 @@ class _InnerPredictor:
offset = 0 offset = 0
for cs_indptr in out_indptr_arrays: for cs_indptr in out_indptr_arrays:
matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1] matrix_indptr_len = cs_indptr[cs_indptr.shape[0] - 1]
cs_indices = out_indices[offset + cs_indptr[0]:offset + matrix_indptr_len] cs_indices = out_indices[offset + cs_indptr[0] : offset + matrix_indptr_len]
cs_data = out_data[offset + cs_indptr[0]:offset + matrix_indptr_len] cs_data = out_data[offset + cs_indptr[0] : offset + matrix_indptr_len]
offset += matrix_indptr_len offset += matrix_indptr_len
# same shape as input csr or csc matrix except extra column for expected value # same shape as input csr or csc matrix except extra column for expected value
cs_shape = [cs.shape[0], cs.shape[1] + 1] cs_shape = [cs.shape[0], cs.shape[1] + 1]
...@@ -1373,8 +1393,15 @@ class _InnerPredictor: ...@@ -1373,8 +1393,15 @@ class _InnerPredictor:
else: else:
cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape)) cs_output_matrices.append(scipy.sparse.csc_matrix((cs_data, cs_indices, cs_indptr), cs_shape))
# free the temporary native indptr, indices, and data # free the temporary native indptr, indices, and data
_safe_call(_LIB.LGBM_BoosterFreePredictSparse(out_ptr_indptr, out_ptr_indices, out_ptr_data, _safe_call(
ctypes.c_int(indptr_type), ctypes.c_int(data_type))) _LIB.LGBM_BoosterFreePredictSparse(
out_ptr_indptr,
out_ptr_indices,
out_ptr_data,
ctypes.c_int(indptr_type),
ctypes.c_int(data_type),
)
)
if len(cs_output_matrices) == 1: if len(cs_output_matrices) == 1:
return cs_output_matrices[0] return cs_output_matrices[0]
return cs_output_matrices return cs_output_matrices
...@@ -1385,14 +1412,14 @@ class _InnerPredictor: ...@@ -1385,14 +1412,14 @@ class _InnerPredictor:
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int, predict_type: int,
preds: Optional[np.ndarray] preds: Optional[np.ndarray],
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
nrow = len(csr.indptr) - 1 nrow = len(csr.indptr) - 1
n_preds = self.__get_num_preds( n_preds = self.__get_num_preds(
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
nrow=nrow, nrow=nrow,
predict_type=predict_type predict_type=predict_type,
) )
if preds is None: if preds is None:
preds = np.empty(n_preds, dtype=np.float64) preds = np.empty(n_preds, dtype=np.float64)
...@@ -1406,22 +1433,25 @@ class _InnerPredictor: ...@@ -1406,22 +1433,25 @@ class _InnerPredictor:
assert csr.shape[1] <= _MAX_INT32 assert csr.shape[1] <= _MAX_INT32
csr_indices = csr.indices.astype(np.int32, copy=False) csr_indices = csr.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_BoosterPredictForCSR( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictForCSR(
ptr_indptr, self._handle,
ctypes.c_int(type_ptr_indptr), ptr_indptr,
csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_indptr),
ptr_data, csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int64(len(csr.indptr)), ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csr.data)), ctypes.c_int64(len(csr.indptr)),
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(len(csr.data)),
ctypes.c_int(predict_type), ctypes.c_int64(csr.shape[1]),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), _c_str(self.pred_parameter),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
)
)
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
raise ValueError("Wrong length for predict results") raise ValueError("Wrong length for predict results")
return preds, nrow return preds, nrow
...@@ -1431,7 +1461,7 @@ class _InnerPredictor: ...@@ -1431,7 +1461,7 @@ class _InnerPredictor:
csr: scipy.sparse.csr_matrix, csr: scipy.sparse.csr_matrix,
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int predict_type: int,
) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]: ) -> Tuple[Union[List[scipy.sparse.csc_matrix], List[scipy.sparse.csr_matrix]], int]:
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr) ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csr.data) ptr_data, type_ptr_data, _ = _c_float_array(csr.data)
...@@ -1449,25 +1479,28 @@ class _InnerPredictor: ...@@ -1449,25 +1479,28 @@ class _InnerPredictor:
else: else:
out_ptr_data = ctypes.POINTER(ctypes.c_double)() out_ptr_data = ctypes.POINTER(ctypes.c_double)()
out_shape = np.empty(2, dtype=np.int64) out_shape = np.empty(2, dtype=np.int64)
_safe_call(_LIB.LGBM_BoosterPredictSparseOutput( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictSparseOutput(
ptr_indptr, self._handle,
ctypes.c_int(type_ptr_indptr), ptr_indptr,
csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_indptr),
ptr_data, csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int64(len(csr.indptr)), ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csr.data)), ctypes.c_int64(len(csr.indptr)),
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(len(csr.data)),
ctypes.c_int(predict_type), ctypes.c_int64(csr.shape[1]),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
ctypes.c_int(matrix_type), _c_str(self.pred_parameter),
out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)), ctypes.c_int(matrix_type),
ctypes.byref(out_ptr_indptr), out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
ctypes.byref(out_ptr_indices), ctypes.byref(out_ptr_indptr),
ctypes.byref(out_ptr_data))) ctypes.byref(out_ptr_indices),
ctypes.byref(out_ptr_data),
)
)
matrices = self.__create_sparse_native( matrices = self.__create_sparse_native(
cs=csr, cs=csr,
out_shape=out_shape, out_shape=out_shape,
...@@ -1476,7 +1509,7 @@ class _InnerPredictor: ...@@ -1476,7 +1509,7 @@ class _InnerPredictor:
out_ptr_data=out_ptr_data, out_ptr_data=out_ptr_data,
indptr_type=type_ptr_indptr, indptr_type=type_ptr_indptr,
data_type=type_ptr_data, data_type=type_ptr_data,
is_csr=True is_csr=True,
) )
nrow = len(csr.indptr) - 1 nrow = len(csr.indptr) - 1
return matrices, nrow return matrices, nrow
...@@ -1486,7 +1519,7 @@ class _InnerPredictor: ...@@ -1486,7 +1519,7 @@ class _InnerPredictor:
csr: scipy.sparse.csr_matrix, csr: scipy.sparse.csr_matrix,
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int predict_type: int,
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
"""Predict for a CSR data.""" """Predict for a CSR data."""
if predict_type == _C_API_PREDICT_CONTRIB: if predict_type == _C_API_PREDICT_CONTRIB:
...@@ -1494,7 +1527,7 @@ class _InnerPredictor: ...@@ -1494,7 +1527,7 @@ class _InnerPredictor:
csr=csr, csr=csr,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
nrow = len(csr.indptr) - 1 nrow = len(csr.indptr) - 1
if nrow > _MAX_INT32: if nrow > _MAX_INT32:
...@@ -1503,15 +1536,16 @@ class _InnerPredictor: ...@@ -1503,15 +1536,16 @@ class _InnerPredictor:
n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff(sections)] n_preds = [self.__get_num_preds(start_iteration, num_iteration, i, predict_type) for i in np.diff(sections)]
n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum() n_preds_sections = np.array([0] + n_preds, dtype=np.intp).cumsum()
preds = np.empty(sum(n_preds), dtype=np.float64) preds = np.empty(sum(n_preds), dtype=np.float64)
for (start_idx, end_idx), (start_idx_pred, end_idx_pred) in zip(zip(sections, sections[1:]), for (start_idx, end_idx), (start_idx_pred, end_idx_pred) in zip(
zip(n_preds_sections, n_preds_sections[1:])): zip(sections, sections[1:]), zip(n_preds_sections, n_preds_sections[1:])
):
# avoid memory consumption by arrays concatenation operations # avoid memory consumption by arrays concatenation operations
self.__inner_predict_csr( self.__inner_predict_csr(
csr=csr[start_idx:end_idx], csr=csr[start_idx:end_idx],
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type, predict_type=predict_type,
preds=preds[start_idx_pred:end_idx_pred] preds=preds[start_idx_pred:end_idx_pred],
) )
return preds, nrow return preds, nrow
else: else:
...@@ -1520,7 +1554,7 @@ class _InnerPredictor: ...@@ -1520,7 +1554,7 @@ class _InnerPredictor:
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type, predict_type=predict_type,
preds=None preds=None,
) )
def __inner_predict_sparse_csc( def __inner_predict_sparse_csc(
...@@ -1528,7 +1562,7 @@ class _InnerPredictor: ...@@ -1528,7 +1562,7 @@ class _InnerPredictor:
csc: scipy.sparse.csc_matrix, csc: scipy.sparse.csc_matrix,
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int predict_type: int,
): ):
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr) ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
ptr_data, type_ptr_data, _ = _c_float_array(csc.data) ptr_data, type_ptr_data, _ = _c_float_array(csc.data)
...@@ -1546,25 +1580,28 @@ class _InnerPredictor: ...@@ -1546,25 +1580,28 @@ class _InnerPredictor:
else: else:
out_ptr_data = ctypes.POINTER(ctypes.c_double)() out_ptr_data = ctypes.POINTER(ctypes.c_double)()
out_shape = np.empty(2, dtype=np.int64) out_shape = np.empty(2, dtype=np.int64)
_safe_call(_LIB.LGBM_BoosterPredictSparseOutput( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictSparseOutput(
ptr_indptr, self._handle,
ctypes.c_int(type_ptr_indptr), ptr_indptr,
csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_indptr),
ptr_data, csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int64(len(csc.indptr)), ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csc.data)), ctypes.c_int64(len(csc.indptr)),
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(len(csc.data)),
ctypes.c_int(predict_type), ctypes.c_int64(csc.shape[0]),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
ctypes.c_int(matrix_type), _c_str(self.pred_parameter),
out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)), ctypes.c_int(matrix_type),
ctypes.byref(out_ptr_indptr), out_shape.ctypes.data_as(ctypes.POINTER(ctypes.c_int64)),
ctypes.byref(out_ptr_indices), ctypes.byref(out_ptr_indptr),
ctypes.byref(out_ptr_data))) ctypes.byref(out_ptr_indices),
ctypes.byref(out_ptr_data),
)
)
matrices = self.__create_sparse_native( matrices = self.__create_sparse_native(
cs=csc, cs=csc,
out_shape=out_shape, out_shape=out_shape,
...@@ -1573,7 +1610,7 @@ class _InnerPredictor: ...@@ -1573,7 +1610,7 @@ class _InnerPredictor:
out_ptr_data=out_ptr_data, out_ptr_data=out_ptr_data,
indptr_type=type_ptr_indptr, indptr_type=type_ptr_indptr,
data_type=type_ptr_data, data_type=type_ptr_data,
is_csr=False is_csr=False,
) )
nrow = csc.shape[0] nrow = csc.shape[0]
return matrices, nrow return matrices, nrow
...@@ -1583,7 +1620,7 @@ class _InnerPredictor: ...@@ -1583,7 +1620,7 @@ class _InnerPredictor:
csc: scipy.sparse.csc_matrix, csc: scipy.sparse.csc_matrix,
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int predict_type: int,
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
"""Predict for a CSC data.""" """Predict for a CSC data."""
nrow = csc.shape[0] nrow = csc.shape[0]
...@@ -1592,20 +1629,20 @@ class _InnerPredictor: ...@@ -1592,20 +1629,20 @@ class _InnerPredictor:
csr=csc.tocsr(), csr=csc.tocsr(),
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
if predict_type == _C_API_PREDICT_CONTRIB: if predict_type == _C_API_PREDICT_CONTRIB:
return self.__inner_predict_sparse_csc( return self.__inner_predict_sparse_csc(
csc=csc, csc=csc,
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
predict_type=predict_type predict_type=predict_type,
) )
n_preds = self.__get_num_preds( n_preds = self.__get_num_preds(
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
nrow=nrow, nrow=nrow,
predict_type=predict_type predict_type=predict_type,
) )
preds = np.empty(n_preds, dtype=np.float64) preds = np.empty(n_preds, dtype=np.float64)
out_num_preds = ctypes.c_int64(0) out_num_preds = ctypes.c_int64(0)
...@@ -1616,32 +1653,35 @@ class _InnerPredictor: ...@@ -1616,32 +1653,35 @@ class _InnerPredictor:
assert csc.shape[0] <= _MAX_INT32 assert csc.shape[0] <= _MAX_INT32
csc_indices = csc.indices.astype(np.int32, copy=False) csc_indices = csc.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_BoosterPredictForCSC( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictForCSC(
ptr_indptr, self._handle,
ctypes.c_int(type_ptr_indptr), ptr_indptr,
csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_indptr),
ptr_data, csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int64(len(csc.indptr)), ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csc.data)), ctypes.c_int64(len(csc.indptr)),
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(len(csc.data)),
ctypes.c_int(predict_type), ctypes.c_int64(csc.shape[0]),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), _c_str(self.pred_parameter),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
)
)
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
raise ValueError("Wrong length for predict results") raise ValueError("Wrong length for predict results")
return preds, nrow return preds, nrow
def __pred_for_pyarrow_table( def __pred_for_pyarrow_table(
self, self,
table: pa_Table, table: pa_Table,
start_iteration: int, start_iteration: int,
num_iteration: int, num_iteration: int,
predict_type: int predict_type: int,
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
"""Predict for a PyArrow table.""" """Predict for a PyArrow table."""
if not PYARROW_INSTALLED: if not PYARROW_INSTALLED:
...@@ -1656,24 +1696,27 @@ class _InnerPredictor: ...@@ -1656,24 +1696,27 @@ class _InnerPredictor:
start_iteration=start_iteration, start_iteration=start_iteration,
num_iteration=num_iteration, num_iteration=num_iteration,
nrow=table.num_rows, nrow=table.num_rows,
predict_type=predict_type predict_type=predict_type,
) )
preds = np.empty(n_preds, dtype=np.float64) preds = np.empty(n_preds, dtype=np.float64)
out_num_preds = ctypes.c_int64(0) out_num_preds = ctypes.c_int64(0)
# Export Arrow table to C and run prediction # Export Arrow table to C and run prediction
c_array = _export_arrow_to_c(table) c_array = _export_arrow_to_c(table)
_safe_call(_LIB.LGBM_BoosterPredictForArrow( _safe_call(
self._handle, _LIB.LGBM_BoosterPredictForArrow(
ctypes.c_int64(c_array.n_chunks), self._handle,
ctypes.c_void_p(c_array.chunks_ptr), ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.schema_ptr), ctypes.c_void_p(c_array.chunks_ptr),
ctypes.c_int(predict_type), ctypes.c_void_p(c_array.schema_ptr),
ctypes.c_int(start_iteration), ctypes.c_int(predict_type),
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
_c_str(self.pred_parameter), ctypes.c_int(num_iteration),
ctypes.byref(out_num_preds), _c_str(self.pred_parameter),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
)
)
if n_preds != out_num_preds.value: if n_preds != out_num_preds.value:
raise ValueError("Wrong length for predict results") raise ValueError("Wrong length for predict results")
return preds, table.num_rows return preds, table.num_rows
...@@ -1687,9 +1730,12 @@ class _InnerPredictor: ...@@ -1687,9 +1730,12 @@ class _InnerPredictor:
The index of the current iteration. The index of the current iteration.
""" """
out_cur_iter = ctypes.c_int(0) out_cur_iter = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration( _safe_call(
self._handle, _LIB.LGBM_BoosterGetCurrentIteration(
ctypes.byref(out_cur_iter))) self._handle,
ctypes.byref(out_cur_iter),
)
)
return out_cur_iter.value return out_cur_iter.value
...@@ -1704,8 +1750,8 @@ class Dataset: ...@@ -1704,8 +1750,8 @@ class Dataset:
weight: Optional[_LGBM_WeightType] = None, weight: Optional[_LGBM_WeightType] = None,
group: Optional[_LGBM_GroupType] = None, group: Optional[_LGBM_GroupType] = None,
init_score: Optional[_LGBM_InitScoreType] = None, init_score: Optional[_LGBM_InitScoreType] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True, free_raw_data: bool = True,
position: Optional[_LGBM_PositionType] = None, position: Optional[_LGBM_PositionType] = None,
...@@ -1800,20 +1846,22 @@ class Dataset: ...@@ -1800,20 +1846,22 @@ class Dataset:
ptr_data, _, _ = _c_int_array(indices) ptr_data, _, _ = _c_int_array(indices)
actual_sample_cnt = ctypes.c_int32(0) actual_sample_cnt = ctypes.c_int32(0)
_safe_call(_LIB.LGBM_SampleIndices( _safe_call(
ctypes.c_int32(total_nrow), _LIB.LGBM_SampleIndices(
_c_str(param_str), ctypes.c_int32(total_nrow),
ptr_data, _c_str(param_str),
ctypes.byref(actual_sample_cnt), ptr_data,
)) ctypes.byref(actual_sample_cnt),
)
)
assert sample_cnt == actual_sample_cnt.value assert sample_cnt == actual_sample_cnt.value
return indices return indices
def _init_from_ref_dataset( def _init_from_ref_dataset(
self, self,
total_nrow: int, total_nrow: int,
ref_dataset: _DatasetHandle ref_dataset: _DatasetHandle,
) -> 'Dataset': ) -> "Dataset":
"""Create dataset from a reference dataset. """Create dataset from a reference dataset.
Parameters Parameters
...@@ -1829,11 +1877,13 @@ class Dataset: ...@@ -1829,11 +1877,13 @@ class Dataset:
Constructed Dataset object. Constructed Dataset object.
""" """
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_DatasetCreateByReference( _safe_call(
ref_dataset, _LIB.LGBM_DatasetCreateByReference(
ctypes.c_int64(total_nrow), ref_dataset,
ctypes.byref(self._handle), ctypes.c_int64(total_nrow),
)) ctypes.byref(self._handle),
)
)
return self return self
def _init_from_sample( def _init_from_sample(
...@@ -1885,20 +1935,22 @@ class Dataset: ...@@ -1885,20 +1935,22 @@ class Dataset:
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
params_str = _param_dict_to_str(self.get_params()) params_str = _param_dict_to_str(self.get_params())
_safe_call(_LIB.LGBM_DatasetCreateFromSampledColumn( _safe_call(
ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))), _LIB.LGBM_DatasetCreateFromSampledColumn(
ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))), ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
ctypes.c_int32(ncol), ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))),
num_per_col_ptr, ctypes.c_int32(ncol),
ctypes.c_int32(sample_cnt), num_per_col_ptr,
ctypes.c_int32(total_nrow), ctypes.c_int32(sample_cnt),
ctypes.c_int64(total_nrow), ctypes.c_int32(total_nrow),
_c_str(params_str), ctypes.c_int64(total_nrow),
ctypes.byref(self._handle), _c_str(params_str),
)) ctypes.byref(self._handle),
)
)
return self return self
def _push_rows(self, data: np.ndarray) -> 'Dataset': def _push_rows(self, data: np.ndarray) -> "Dataset":
"""Add rows to Dataset. """Add rows to Dataset.
Parameters Parameters
...@@ -1915,14 +1967,16 @@ class Dataset: ...@@ -1915,14 +1967,16 @@ class Dataset:
data = data.reshape(data.size) data = data.reshape(data.size)
data_ptr, data_type, _ = _c_float_array(data) data_ptr, data_type, _ = _c_float_array(data)
_safe_call(_LIB.LGBM_DatasetPushRows( _safe_call(
self._handle, _LIB.LGBM_DatasetPushRows(
data_ptr, self._handle,
data_type, data_ptr,
ctypes.c_int32(nrow), data_type,
ctypes.c_int32(ncol), ctypes.c_int32(nrow),
ctypes.c_int32(self._start_row), ctypes.c_int32(ncol),
)) ctypes.c_int32(self._start_row),
)
)
self._start_row += nrow self._start_row += nrow
return self return self
...@@ -1936,27 +1990,29 @@ class Dataset: ...@@ -1936,27 +1990,29 @@ class Dataset:
""" """
if self.params is not None: if self.params is not None:
# no min_data, nthreads and verbose in this function # no min_data, nthreads and verbose in this function
dataset_params = _ConfigAliases.get("bin_construct_sample_cnt", dataset_params = _ConfigAliases.get(
"categorical_feature", "bin_construct_sample_cnt",
"data_random_seed", "categorical_feature",
"enable_bundle", "data_random_seed",
"feature_pre_filter", "enable_bundle",
"forcedbins_filename", "feature_pre_filter",
"group_column", "forcedbins_filename",
"header", "group_column",
"ignore_column", "header",
"is_enable_sparse", "ignore_column",
"label_column", "is_enable_sparse",
"linear_tree", "label_column",
"max_bin", "linear_tree",
"max_bin_by_feature", "max_bin",
"min_data_in_bin", "max_bin_by_feature",
"pre_partition", "min_data_in_bin",
"precise_float_parser", "pre_partition",
"two_round", "precise_float_parser",
"use_missing", "two_round",
"weight_column", "use_missing",
"zero_as_missing") "weight_column",
"zero_as_missing",
)
return {k: v for k, v in self.params.items() if k in dataset_params} return {k: v for k, v in self.params.items() if k in dataset_params}
else: else:
return {} return {}
...@@ -1974,7 +2030,7 @@ class Dataset: ...@@ -1974,7 +2030,7 @@ class Dataset:
self, self,
predictor: Optional[_InnerPredictor], predictor: Optional[_InnerPredictor],
data: _LGBM_TrainDataType, data: _LGBM_TrainDataType,
used_indices: Optional[Union[List[int], np.ndarray]] used_indices: Optional[Union[List[int], np.ndarray]],
) -> "Dataset": ) -> "Dataset":
data_has_header = False data_has_header = False
if isinstance(data, (str, Path)) and self.params is not None: if isinstance(data, (str, Path)) and self.params is not None:
...@@ -1985,7 +2041,7 @@ class Dataset: ...@@ -1985,7 +2041,7 @@ class Dataset:
init_score: Union[np.ndarray, scipy.sparse.spmatrix] = predictor.predict( init_score: Union[np.ndarray, scipy.sparse.spmatrix] = predictor.predict(
data=data, data=data,
raw_score=True, raw_score=True,
data_has_header=data_has_header data_has_header=data_has_header,
) )
init_score = init_score.ravel() init_score = init_score.ravel()
if used_indices is not None: if used_indices is not None:
...@@ -1995,7 +2051,9 @@ class Dataset: ...@@ -1995,7 +2051,9 @@ class Dataset:
assert num_data == len(used_indices) assert num_data == len(used_indices)
for i in range(len(used_indices)): for i in range(len(used_indices)):
for j in range(predictor.num_class): for j in range(predictor.num_class):
sub_init_score[i * predictor.num_class + j] = init_score[used_indices[i] * predictor.num_class + j] sub_init_score[i * predictor.num_class + j] = init_score[
used_indices[i] * predictor.num_class + j
]
init_score = sub_init_score init_score = sub_init_score
if predictor.num_class > 1: if predictor.num_class > 1:
# need to regroup init_score # need to regroup init_score
...@@ -2023,7 +2081,7 @@ class Dataset: ...@@ -2023,7 +2081,7 @@ class Dataset:
feature_name: _LGBM_FeatureNameConfiguration, feature_name: _LGBM_FeatureNameConfiguration,
categorical_feature: _LGBM_CategoricalFeatureConfiguration, categorical_feature: _LGBM_CategoricalFeatureConfiguration,
params: Optional[Dict[str, Any]], params: Optional[Dict[str, Any]],
position: Optional[_LGBM_PositionType] position: Optional[_LGBM_PositionType],
) -> "Dataset": ) -> "Dataset":
if data is None: if data is None:
self._handle = None self._handle = None
...@@ -2036,7 +2094,7 @@ class Dataset: ...@@ -2036,7 +2094,7 @@ class Dataset:
data=data, data=data,
feature_name=feature_name, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
pandas_categorical=self.pandas_categorical pandas_categorical=self.pandas_categorical,
) )
# process for args # process for args
...@@ -2044,8 +2102,10 @@ class Dataset: ...@@ -2044,8 +2102,10 @@ class Dataset:
args_names = inspect.signature(self.__class__._lazy_init).parameters.keys() args_names = inspect.signature(self.__class__._lazy_init).parameters.keys()
for key in params.keys(): for key in params.keys():
if key in args_names: if key in args_names:
_log_warning(f'{key} keyword has been found in `params` and will be ignored.\n' _log_warning(
f'Please use {key} argument of the Dataset constructor to pass this parameter.') f"{key} keyword has been found in `params` and will be ignored.\n"
f"Please use {key} argument of the Dataset constructor to pass this parameter."
)
# get categorical features # get categorical features
if isinstance(categorical_feature, list): if isinstance(categorical_feature, list):
categorical_indices = set() categorical_indices = set()
...@@ -2064,9 +2124,9 @@ class Dataset: ...@@ -2064,9 +2124,9 @@ class Dataset:
if cat_alias in params: if cat_alias in params:
# If the params[cat_alias] is equal to categorical_indices, do not report the warning. # If the params[cat_alias] is equal to categorical_indices, do not report the warning.
if not (isinstance(params[cat_alias], list) and set(params[cat_alias]) == categorical_indices): if not (isinstance(params[cat_alias], list) and set(params[cat_alias]) == categorical_indices):
_log_warning(f'{cat_alias} in param dict is overridden.') _log_warning(f"{cat_alias} in param dict is overridden.")
params.pop(cat_alias, None) params.pop(cat_alias, None)
params['categorical_column'] = sorted(categorical_indices) params["categorical_column"] = sorted(categorical_indices)
params_str = _param_dict_to_str(params) params_str = _param_dict_to_str(params)
self.params = params self.params = params
...@@ -2075,15 +2135,18 @@ class Dataset: ...@@ -2075,15 +2135,18 @@ class Dataset:
if isinstance(reference, Dataset): if isinstance(reference, Dataset):
ref_dataset = reference.construct()._handle ref_dataset = reference.construct()._handle
elif reference is not None: elif reference is not None:
raise TypeError('Reference dataset should be None or dataset instance') raise TypeError("Reference dataset should be None or dataset instance")
# start construct data # start construct data
if isinstance(data, (str, Path)): if isinstance(data, (str, Path)):
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_DatasetCreateFromFile( _safe_call(
_c_str(str(data)), _LIB.LGBM_DatasetCreateFromFile(
_c_str(params_str), _c_str(str(data)),
ref_dataset, _c_str(params_str),
ctypes.byref(self._handle))) ref_dataset,
ctypes.byref(self._handle),
)
)
elif isinstance(data, scipy.sparse.csr_matrix): elif isinstance(data, scipy.sparse.csr_matrix):
self.__init_from_csr(data, params_str, ref_dataset) self.__init_from_csr(data, params_str, ref_dataset)
elif isinstance(data, scipy.sparse.csc_matrix): elif isinstance(data, scipy.sparse.csc_matrix):
...@@ -2099,7 +2162,7 @@ class Dataset: ...@@ -2099,7 +2162,7 @@ class Dataset:
elif _is_list_of_sequences(data): elif _is_list_of_sequences(data):
self.__init_from_seqs(data, ref_dataset) self.__init_from_seqs(data, ref_dataset)
else: else:
raise TypeError('Data list can only be of ndarray or Sequence') raise TypeError("Data list can only be of ndarray or Sequence")
elif isinstance(data, Sequence): elif isinstance(data, Sequence):
self.__init_from_seqs([data], ref_dataset) self.__init_from_seqs([data], ref_dataset)
elif isinstance(data, dt_DataTable): elif isinstance(data, dt_DataTable):
...@@ -2109,7 +2172,7 @@ class Dataset: ...@@ -2109,7 +2172,7 @@ class Dataset:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
self.__init_from_csr(csr, params_str, ref_dataset) self.__init_from_csr(csr, params_str, ref_dataset)
except BaseException as err: except BaseException as err:
raise TypeError(f'Cannot initialize Dataset from {type(data).__name__}') from err raise TypeError(f"Cannot initialize Dataset from {type(data).__name__}") from err
if label is not None: if label is not None:
self.set_label(label) self.set_label(label)
if self.get_label() is None: if self.get_label() is None:
...@@ -2123,15 +2186,11 @@ class Dataset: ...@@ -2123,15 +2186,11 @@ class Dataset:
if isinstance(predictor, _InnerPredictor): if isinstance(predictor, _InnerPredictor):
if self._predictor is None and init_score is not None: if self._predictor is None and init_score is not None:
_log_warning("The init_score will be overridden by the prediction of init_model.") _log_warning("The init_score will be overridden by the prediction of init_model.")
self._set_init_score_by_predictor( self._set_init_score_by_predictor(predictor=predictor, data=data, used_indices=None)
predictor=predictor,
data=data,
used_indices=None
)
elif init_score is not None: elif init_score is not None:
self.set_init_score(init_score) self.set_init_score(init_score)
elif predictor is not None: elif predictor is not None:
raise TypeError(f'Wrong predictor type {type(predictor).__name__}') raise TypeError(f"Wrong predictor type {type(predictor).__name__}")
# set feature names # set feature names
return self.set_feature_name(feature_name) return self.set_feature_name(feature_name)
...@@ -2148,7 +2207,7 @@ class Dataset: ...@@ -2148,7 +2207,7 @@ class Dataset:
seq = seqs[seq_id] seq = seqs[seq_id]
id_in_seq = row_id - offset id_in_seq = row_id - offset
row = seq[id_in_seq] row = seq[id_in_seq]
yield row if row.flags['OWNDATA'] else row.copy() yield row if row.flags["OWNDATA"] else row.copy()
def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]: def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Sample data from seqs. """Sample data from seqs.
...@@ -2181,7 +2240,7 @@ class Dataset: ...@@ -2181,7 +2240,7 @@ class Dataset:
def __init_from_seqs( def __init_from_seqs(
self, self,
seqs: List[Sequence], seqs: List[Sequence],
ref_dataset: Optional[_DatasetHandle] ref_dataset: Optional[_DatasetHandle],
) -> "Dataset": ) -> "Dataset":
""" """
Initialize data from list of Sequence objects. Initialize data from list of Sequence objects.
...@@ -2205,7 +2264,7 @@ class Dataset: ...@@ -2205,7 +2264,7 @@ class Dataset:
for seq in seqs: for seq in seqs:
nrow = len(seq) nrow = len(seq)
batch_size = getattr(seq, 'batch_size', None) or Sequence.batch_size batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
for start in range(0, nrow, batch_size): for start in range(0, nrow, batch_size):
end = min(start + batch_size, nrow) end = min(start + batch_size, nrow)
self._push_rows(seq[start:end]) self._push_rows(seq[start:end])
...@@ -2215,11 +2274,11 @@ class Dataset: ...@@ -2215,11 +2274,11 @@ class Dataset:
self, self,
mat: np.ndarray, mat: np.ndarray,
params_str: str, params_str: str,
ref_dataset: Optional[_DatasetHandle] ref_dataset: Optional[_DatasetHandle],
) -> "Dataset": ) -> "Dataset":
"""Initialize data from a 2-D numpy matrix.""" """Initialize data from a 2-D numpy matrix."""
if len(mat.shape) != 2: if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional') raise ValueError("Input numpy.ndarray must be 2 dimensional")
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
if mat.dtype == np.float32 or mat.dtype == np.float64: if mat.dtype == np.float32 or mat.dtype == np.float64:
...@@ -2228,22 +2287,25 @@ class Dataset: ...@@ -2228,22 +2287,25 @@ class Dataset:
data = np.array(mat.reshape(mat.size), dtype=np.float32) data = np.array(mat.reshape(mat.size), dtype=np.float32)
ptr_data, type_ptr_data, _ = _c_float_array(data) ptr_data, type_ptr_data, _ = _c_float_array(data)
_safe_call(_LIB.LGBM_DatasetCreateFromMat( _safe_call(
ptr_data, _LIB.LGBM_DatasetCreateFromMat(
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int32(mat.shape[0]), ctypes.c_int(type_ptr_data),
ctypes.c_int32(mat.shape[1]), ctypes.c_int32(mat.shape[0]),
ctypes.c_int(_C_API_IS_ROW_MAJOR), ctypes.c_int32(mat.shape[1]),
_c_str(params_str), ctypes.c_int(_C_API_IS_ROW_MAJOR),
ref_dataset, _c_str(params_str),
ctypes.byref(self._handle))) ref_dataset,
ctypes.byref(self._handle),
)
)
return self return self
def __init_from_list_np2d( def __init_from_list_np2d(
self, self,
mats: List[np.ndarray], mats: List[np.ndarray],
params_str: str, params_str: str,
ref_dataset: Optional[_DatasetHandle] ref_dataset: Optional[_DatasetHandle],
) -> "Dataset": ) -> "Dataset":
"""Initialize data from a list of 2-D numpy matrices.""" """Initialize data from a list of 2-D numpy matrices."""
ncol = mats[0].shape[1] ncol = mats[0].shape[1]
...@@ -2259,10 +2321,10 @@ class Dataset: ...@@ -2259,10 +2321,10 @@ class Dataset:
for i, mat in enumerate(mats): for i, mat in enumerate(mats):
if len(mat.shape) != 2: if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional') raise ValueError("Input numpy.ndarray must be 2 dimensional")
if mat.shape[1] != ncol: if mat.shape[1] != ncol:
raise ValueError('Input arrays must have same number of columns') raise ValueError("Input arrays must have same number of columns")
nrow[i] = mat.shape[0] nrow[i] = mat.shape[0]
...@@ -2273,33 +2335,36 @@ class Dataset: ...@@ -2273,33 +2335,36 @@ class Dataset:
chunk_ptr_data, chunk_type_ptr_data, holder = _c_float_array(mats[i]) chunk_ptr_data, chunk_type_ptr_data, holder = _c_float_array(mats[i])
if type_ptr_data != -1 and chunk_type_ptr_data != type_ptr_data: if type_ptr_data != -1 and chunk_type_ptr_data != type_ptr_data:
raise ValueError('Input chunks must have same type') raise ValueError("Input chunks must have same type")
ptr_data[i] = chunk_ptr_data ptr_data[i] = chunk_ptr_data
type_ptr_data = chunk_type_ptr_data type_ptr_data = chunk_type_ptr_data
holders.append(holder) holders.append(holder)
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_DatasetCreateFromMats( _safe_call(
ctypes.c_int32(len(mats)), _LIB.LGBM_DatasetCreateFromMats(
ctypes.cast(ptr_data, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))), ctypes.c_int32(len(mats)),
ctypes.c_int(type_ptr_data), ctypes.cast(ptr_data, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))),
nrow.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_data),
ctypes.c_int32(ncol), nrow.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(_C_API_IS_ROW_MAJOR), ctypes.c_int32(ncol),
_c_str(params_str), ctypes.c_int(_C_API_IS_ROW_MAJOR),
ref_dataset, _c_str(params_str),
ctypes.byref(self._handle))) ref_dataset,
ctypes.byref(self._handle),
)
)
return self return self
def __init_from_csr( def __init_from_csr(
self, self,
csr: scipy.sparse.csr_matrix, csr: scipy.sparse.csr_matrix,
params_str: str, params_str: str,
ref_dataset: Optional[_DatasetHandle] ref_dataset: Optional[_DatasetHandle],
) -> "Dataset": ) -> "Dataset":
"""Initialize data from a CSR matrix.""" """Initialize data from a CSR matrix."""
if len(csr.indices) != len(csr.data): if len(csr.indices) != len(csr.data):
raise ValueError(f'Length mismatch: {len(csr.indices)} vs {len(csr.data)}') raise ValueError(f"Length mismatch: {len(csr.indices)} vs {len(csr.data)}")
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr) ptr_indptr, type_ptr_indptr, __ = _c_int_array(csr.indptr)
...@@ -2308,29 +2373,32 @@ class Dataset: ...@@ -2308,29 +2373,32 @@ class Dataset:
assert csr.shape[1] <= _MAX_INT32 assert csr.shape[1] <= _MAX_INT32
csr_indices = csr.indices.astype(np.int32, copy=False) csr_indices = csr.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_DatasetCreateFromCSR( _safe_call(
ptr_indptr, _LIB.LGBM_DatasetCreateFromCSR(
ctypes.c_int(type_ptr_indptr), ptr_indptr,
csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_indptr),
ptr_data, csr_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int64(len(csr.indptr)), ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csr.data)), ctypes.c_int64(len(csr.indptr)),
ctypes.c_int64(csr.shape[1]), ctypes.c_int64(len(csr.data)),
_c_str(params_str), ctypes.c_int64(csr.shape[1]),
ref_dataset, _c_str(params_str),
ctypes.byref(self._handle))) ref_dataset,
ctypes.byref(self._handle),
)
)
return self return self
def __init_from_csc( def __init_from_csc(
self, self,
csc: scipy.sparse.csc_matrix, csc: scipy.sparse.csc_matrix,
params_str: str, params_str: str,
ref_dataset: Optional[_DatasetHandle] ref_dataset: Optional[_DatasetHandle],
) -> "Dataset": ) -> "Dataset":
"""Initialize data from a CSC matrix.""" """Initialize data from a CSC matrix."""
if len(csc.indices) != len(csc.data): if len(csc.indices) != len(csc.data):
raise ValueError(f'Length mismatch: {len(csc.indices)} vs {len(csc.data)}') raise ValueError(f"Length mismatch: {len(csc.indices)} vs {len(csc.data)}")
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr) ptr_indptr, type_ptr_indptr, __ = _c_int_array(csc.indptr)
...@@ -2339,25 +2407,28 @@ class Dataset: ...@@ -2339,25 +2407,28 @@ class Dataset:
assert csc.shape[0] <= _MAX_INT32 assert csc.shape[0] <= _MAX_INT32
csc_indices = csc.indices.astype(np.int32, copy=False) csc_indices = csc.indices.astype(np.int32, copy=False)
_safe_call(_LIB.LGBM_DatasetCreateFromCSC( _safe_call(
ptr_indptr, _LIB.LGBM_DatasetCreateFromCSC(
ctypes.c_int(type_ptr_indptr), ptr_indptr,
csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), ctypes.c_int(type_ptr_indptr),
ptr_data, csc_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
ctypes.c_int(type_ptr_data), ptr_data,
ctypes.c_int64(len(csc.indptr)), ctypes.c_int(type_ptr_data),
ctypes.c_int64(len(csc.data)), ctypes.c_int64(len(csc.indptr)),
ctypes.c_int64(csc.shape[0]), ctypes.c_int64(len(csc.data)),
_c_str(params_str), ctypes.c_int64(csc.shape[0]),
ref_dataset, _c_str(params_str),
ctypes.byref(self._handle))) ref_dataset,
ctypes.byref(self._handle),
)
)
return self return self
def __init_from_pyarrow_table( def __init_from_pyarrow_table(
self, self,
table: pa_Table, table: pa_Table,
params_str: str, params_str: str,
ref_dataset: Optional[_DatasetHandle] ref_dataset: Optional[_DatasetHandle],
) -> "Dataset": ) -> "Dataset":
"""Initialize data from a PyArrow table.""" """Initialize data from a PyArrow table."""
if not PYARROW_INSTALLED: if not PYARROW_INSTALLED:
...@@ -2370,20 +2441,23 @@ class Dataset: ...@@ -2370,20 +2441,23 @@ class Dataset:
# Export Arrow table to C # Export Arrow table to C
c_array = _export_arrow_to_c(table) c_array = _export_arrow_to_c(table)
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_DatasetCreateFromArrow( _safe_call(
ctypes.c_int64(c_array.n_chunks), _LIB.LGBM_DatasetCreateFromArrow(
ctypes.c_void_p(c_array.chunks_ptr), ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.schema_ptr), ctypes.c_void_p(c_array.chunks_ptr),
_c_str(params_str), ctypes.c_void_p(c_array.schema_ptr),
ref_dataset, _c_str(params_str),
ctypes.byref(self._handle))) ref_dataset,
ctypes.byref(self._handle),
)
)
return self return self
@staticmethod @staticmethod
def _compare_params_for_warning( def _compare_params_for_warning(
params: Dict[str, Any], params: Dict[str, Any],
other_params: Dict[str, Any], other_params: Dict[str, Any],
ignore_keys: Set[str] ignore_keys: Set[str],
) -> bool: ) -> bool:
"""Compare two dictionaries with params ignoring some keys. """Compare two dictionaries with params ignoring some keys.
...@@ -2429,32 +2503,45 @@ class Dataset: ...@@ -2429,32 +2503,45 @@ class Dataset:
if not self._compare_params_for_warning( if not self._compare_params_for_warning(
params=params, params=params,
other_params=reference_params, other_params=reference_params,
ignore_keys=_ConfigAliases.get("categorical_feature") ignore_keys=_ConfigAliases.get("categorical_feature"),
): ):
_log_warning('Overriding the parameters from Reference Dataset.') _log_warning("Overriding the parameters from Reference Dataset.")
self._update_params(reference_params) self._update_params(reference_params)
if self.used_indices is None: if self.used_indices is None:
# create valid # create valid
self._lazy_init(data=self.data, label=self.label, reference=self.reference, self._lazy_init(
weight=self.weight, group=self.group, position=self.position, data=self.data,
init_score=self.init_score, predictor=self._predictor, label=self.label,
feature_name=self.feature_name, categorical_feature='auto', params=self.params) reference=self.reference,
weight=self.weight,
group=self.group,
position=self.position,
init_score=self.init_score,
predictor=self._predictor,
feature_name=self.feature_name,
categorical_feature="auto",
params=self.params,
)
else: else:
# construct subset # construct subset
used_indices = _list_to_1d_numpy(self.used_indices, dtype=np.int32, name='used_indices') used_indices = _list_to_1d_numpy(self.used_indices, dtype=np.int32, name="used_indices")
assert used_indices.flags.c_contiguous assert used_indices.flags.c_contiguous
if self.reference.group is not None: if self.reference.group is not None:
group_info = np.array(self.reference.group).astype(np.int32, copy=False) group_info = np.array(self.reference.group).astype(np.int32, copy=False)
_, self.group = np.unique(np.repeat(range(len(group_info)), repeats=group_info)[self.used_indices], _, self.group = np.unique(
return_counts=True) np.repeat(range(len(group_info)), repeats=group_info)[self.used_indices], return_counts=True
)
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
params_str = _param_dict_to_str(self.params) params_str = _param_dict_to_str(self.params)
_safe_call(_LIB.LGBM_DatasetGetSubset( _safe_call(
self.reference.construct()._handle, _LIB.LGBM_DatasetGetSubset(
used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), self.reference.construct()._handle,
ctypes.c_int32(used_indices.shape[0]), used_indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
_c_str(params_str), ctypes.c_int32(used_indices.shape[0]),
ctypes.byref(self._handle))) _c_str(params_str),
ctypes.byref(self._handle),
)
)
if not self.free_raw_data: if not self.free_raw_data:
self.get_data() self.get_data()
if self.group is not None: if self.group is not None:
...@@ -2463,20 +2550,29 @@ class Dataset: ...@@ -2463,20 +2550,29 @@ class Dataset:
self.set_position(self.position) self.set_position(self.position)
if self.get_label() is None: if self.get_label() is None:
raise ValueError("Label should not be None.") raise ValueError("Label should not be None.")
if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor: if (
isinstance(self._predictor, _InnerPredictor)
and self._predictor is not self.reference._predictor
):
self.get_data() self.get_data()
self._set_init_score_by_predictor( self._set_init_score_by_predictor(
predictor=self._predictor, predictor=self._predictor, data=self.data, used_indices=used_indices
data=self.data,
used_indices=used_indices
) )
else: else:
# create train # create train
self._lazy_init(data=self.data, label=self.label, reference=None, self._lazy_init(
weight=self.weight, group=self.group, data=self.data,
init_score=self.init_score, predictor=self._predictor, label=self.label,
feature_name=self.feature_name, categorical_feature=self.categorical_feature, reference=None,
params=self.params, position=self.position) weight=self.weight,
group=self.group,
init_score=self.init_score,
predictor=self._predictor,
feature_name=self.feature_name,
categorical_feature=self.categorical_feature,
params=self.params,
position=self.position,
)
if self.free_raw_data: if self.free_raw_data:
self.data = None self.data = None
self.feature_name = self.get_feature_name() self.feature_name = self.get_feature_name()
...@@ -2490,7 +2586,7 @@ class Dataset: ...@@ -2490,7 +2586,7 @@ class Dataset:
group: Optional[_LGBM_GroupType] = None, group: Optional[_LGBM_GroupType] = None,
init_score: Optional[_LGBM_InitScoreType] = None, init_score: Optional[_LGBM_InitScoreType] = None,
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
position: Optional[_LGBM_PositionType] = None position: Optional[_LGBM_PositionType] = None,
) -> "Dataset": ) -> "Dataset":
"""Create validation data align with current Dataset. """Create validation data align with current Dataset.
...@@ -2521,9 +2617,17 @@ class Dataset: ...@@ -2521,9 +2617,17 @@ class Dataset:
valid : Dataset valid : Dataset
Validation Dataset with reference to self. Validation Dataset with reference to self.
""" """
ret = Dataset(data, label=label, reference=self, ret = Dataset(
weight=weight, group=group, position=position, init_score=init_score, data,
params=params, free_raw_data=self.free_raw_data) label=label,
reference=self,
weight=weight,
group=group,
position=position,
init_score=init_score,
params=params,
free_raw_data=self.free_raw_data,
)
ret._predictor = self._predictor ret._predictor = self._predictor
ret.pandas_categorical = self.pandas_categorical ret.pandas_categorical = self.pandas_categorical
return ret return ret
...@@ -2531,7 +2635,7 @@ class Dataset: ...@@ -2531,7 +2635,7 @@ class Dataset:
def subset( def subset(
self, self,
used_indices: List[int], used_indices: List[int],
params: Optional[Dict[str, Any]] = None params: Optional[Dict[str, Any]] = None,
) -> "Dataset": ) -> "Dataset":
"""Get subset of current Dataset. """Get subset of current Dataset.
...@@ -2549,9 +2653,14 @@ class Dataset: ...@@ -2549,9 +2653,14 @@ class Dataset:
""" """
if params is None: if params is None:
params = self.params params = self.params
ret = Dataset(None, reference=self, feature_name=self.feature_name, ret = Dataset(
categorical_feature=self.categorical_feature, params=params, None,
free_raw_data=self.free_raw_data) reference=self,
feature_name=self.feature_name,
categorical_feature=self.categorical_feature,
params=params,
free_raw_data=self.free_raw_data,
)
ret._predictor = self._predictor ret._predictor = self._predictor
ret.pandas_categorical = self.pandas_categorical ret.pandas_categorical = self.pandas_categorical
ret.used_indices = sorted(used_indices) ret.used_indices = sorted(used_indices)
...@@ -2575,9 +2684,12 @@ class Dataset: ...@@ -2575,9 +2684,12 @@ class Dataset:
self : Dataset self : Dataset
Returns self. Returns self.
""" """
_safe_call(_LIB.LGBM_DatasetSaveBinary( _safe_call(
self.construct()._handle, _LIB.LGBM_DatasetSaveBinary(
_c_str(str(filename)))) self.construct()._handle,
_c_str(str(filename)),
)
)
return self return self
def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset": def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset":
...@@ -2597,14 +2709,15 @@ class Dataset: ...@@ -2597,14 +2709,15 @@ class Dataset:
elif params is not None: elif params is not None:
ret = _LIB.LGBM_DatasetUpdateParamChecking( ret = _LIB.LGBM_DatasetUpdateParamChecking(
_c_str(_param_dict_to_str(self.params)), _c_str(_param_dict_to_str(self.params)),
_c_str(_param_dict_to_str(params))) _c_str(_param_dict_to_str(params)),
)
if ret != 0: if ret != 0:
# could be updated if data is not freed # could be updated if data is not freed
if self.data is not None: if self.data is not None:
update() update()
self._free_handle() self._free_handle()
else: else:
raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8')) raise LightGBMError(_LIB.LGBM_GetLastError().decode("utf-8"))
return self return self
def _reverse_update_params(self) -> "Dataset": def _reverse_update_params(self) -> "Dataset":
...@@ -2616,7 +2729,7 @@ class Dataset: ...@@ -2616,7 +2729,7 @@ class Dataset:
def set_field( def set_field(
self, self,
field_name: str, field_name: str,
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Table, pa_Array, pa_ChunkedArray]] data: Optional[_LGBM_SetFieldType],
) -> "Dataset": ) -> "Dataset":
"""Set property into the Dataset. """Set property into the Dataset.
...@@ -2636,12 +2749,15 @@ class Dataset: ...@@ -2636,12 +2749,15 @@ class Dataset:
raise Exception(f"Cannot set {field_name} before construct dataset") raise Exception(f"Cannot set {field_name} before construct dataset")
if data is None: if data is None:
# set to None # set to None
_safe_call(_LIB.LGBM_DatasetSetField( _safe_call(
self._handle, _LIB.LGBM_DatasetSetField(
_c_str(field_name), self._handle,
None, _c_str(field_name),
ctypes.c_int(0), None,
ctypes.c_int(_FIELD_TYPE_MAPPER[field_name]))) ctypes.c_int(0),
ctypes.c_int(_FIELD_TYPE_MAPPER[field_name]),
)
)
return self return self
# If the data is a arrow data, we can just pass it to C # If the data is a arrow data, we can just pass it to C
...@@ -2651,36 +2767,42 @@ class Dataset: ...@@ -2651,36 +2767,42 @@ class Dataset:
if _is_pyarrow_table(data): if _is_pyarrow_table(data):
if field_name != "init_score": if field_name != "init_score":
raise ValueError(f"pyarrow tables are not supported for field '{field_name}'") raise ValueError(f"pyarrow tables are not supported for field '{field_name}'")
data = pa_chunked_array([ data = pa_chunked_array(
chunk for array in data.columns for chunk in array.chunks # type: ignore [
]) chunk
for array in data.columns # type: ignore
for chunk in array.chunks
]
)
c_array = _export_arrow_to_c(data) c_array = _export_arrow_to_c(data)
_safe_call(_LIB.LGBM_DatasetSetFieldFromArrow( _safe_call(
self._handle, _LIB.LGBM_DatasetSetFieldFromArrow(
_c_str(field_name), self._handle,
ctypes.c_int64(c_array.n_chunks), _c_str(field_name),
ctypes.c_void_p(c_array.chunks_ptr), ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.schema_ptr), ctypes.c_void_p(c_array.chunks_ptr),
)) ctypes.c_void_p(c_array.schema_ptr),
)
)
self.version += 1 self.version += 1
return self return self
dtype: "np.typing.DTypeLike" dtype: "np.typing.DTypeLike"
if field_name == 'init_score': if field_name == "init_score":
dtype = np.float64 dtype = np.float64
if _is_1d_collection(data): if _is_1d_collection(data):
data = _list_to_1d_numpy(data, dtype=dtype, name=field_name) data = _list_to_1d_numpy(data, dtype=dtype, name=field_name)
elif _is_2d_collection(data): elif _is_2d_collection(data):
data = _data_to_2d_numpy(data, dtype=dtype, name=field_name) data = _data_to_2d_numpy(data, dtype=dtype, name=field_name)
data = data.ravel(order='F') data = data.ravel(order="F")
else: else:
raise TypeError( raise TypeError(
'init_score must be list, numpy 1-D array or pandas Series.\n' "init_score must be list, numpy 1-D array or pandas Series.\n"
'In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame.' "In multiclass classification init_score can also be a list of lists, numpy 2-D array or pandas DataFrame."
) )
else: else:
if field_name in {'group', 'position'}: if field_name in {"group", "position"}:
dtype = np.int32 dtype = np.int32
else: else:
dtype = np.float32 dtype = np.float32
...@@ -2695,12 +2817,15 @@ class Dataset: ...@@ -2695,12 +2817,15 @@ class Dataset:
raise TypeError(f"Expected np.float32/64 or np.int32, met type({data.dtype})") raise TypeError(f"Expected np.float32/64 or np.int32, met type({data.dtype})")
if type_data != _FIELD_TYPE_MAPPER[field_name]: if type_data != _FIELD_TYPE_MAPPER[field_name]:
raise TypeError("Input type error for set_field") raise TypeError("Input type error for set_field")
_safe_call(_LIB.LGBM_DatasetSetField( _safe_call(
self._handle, _LIB.LGBM_DatasetSetField(
_c_str(field_name), self._handle,
ptr_data, _c_str(field_name),
ctypes.c_int(len(data)), ptr_data,
ctypes.c_int(type_data))) ctypes.c_int(len(data)),
ctypes.c_int(type_data),
)
)
self.version += 1 self.version += 1
return self return self
...@@ -2728,12 +2853,15 @@ class Dataset: ...@@ -2728,12 +2853,15 @@ class Dataset:
tmp_out_len = ctypes.c_int(0) tmp_out_len = ctypes.c_int(0)
out_type = ctypes.c_int(0) out_type = ctypes.c_int(0)
ret = ctypes.POINTER(ctypes.c_void_p)() ret = ctypes.POINTER(ctypes.c_void_p)()
_safe_call(_LIB.LGBM_DatasetGetField( _safe_call(
self._handle, _LIB.LGBM_DatasetGetField(
_c_str(field_name), self._handle,
ctypes.byref(tmp_out_len), _c_str(field_name),
ctypes.byref(ret), ctypes.byref(tmp_out_len),
ctypes.byref(out_type))) ctypes.byref(ret),
ctypes.byref(out_type),
)
)
if out_type.value != _FIELD_TYPE_MAPPER[field_name]: if out_type.value != _FIELD_TYPE_MAPPER[field_name]:
raise TypeError("Return type error for get_field") raise TypeError("Return type error for get_field")
if tmp_out_len.value == 0: if tmp_out_len.value == 0:
...@@ -2741,30 +2869,30 @@ class Dataset: ...@@ -2741,30 +2869,30 @@ class Dataset:
if out_type.value == _C_API_DTYPE_INT32: if out_type.value == _C_API_DTYPE_INT32:
arr = _cint32_array_to_numpy( arr = _cint32_array_to_numpy(
cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)), cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_int32)),
length=tmp_out_len.value length=tmp_out_len.value,
) )
elif out_type.value == _C_API_DTYPE_FLOAT32: elif out_type.value == _C_API_DTYPE_FLOAT32:
arr = _cfloat32_array_to_numpy( arr = _cfloat32_array_to_numpy(
cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)), cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_float)),
length=tmp_out_len.value length=tmp_out_len.value,
) )
elif out_type.value == _C_API_DTYPE_FLOAT64: elif out_type.value == _C_API_DTYPE_FLOAT64:
arr = _cfloat64_array_to_numpy( arr = _cfloat64_array_to_numpy(
cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)), cptr=ctypes.cast(ret, ctypes.POINTER(ctypes.c_double)),
length=tmp_out_len.value length=tmp_out_len.value,
) )
else: else:
raise TypeError("Unknown type") raise TypeError("Unknown type")
if field_name == 'init_score': if field_name == "init_score":
num_data = self.num_data() num_data = self.num_data()
num_classes = arr.size // num_data num_classes = arr.size // num_data
if num_classes > 1: if num_classes > 1:
arr = arr.reshape((num_data, num_classes), order='F') arr = arr.reshape((num_data, num_classes), order="F")
return arr return arr
def set_categorical_feature( def set_categorical_feature(
self, self,
categorical_feature: _LGBM_CategoricalFeatureConfiguration categorical_feature: _LGBM_CategoricalFeatureConfiguration,
) -> "Dataset": ) -> "Dataset":
"""Set categorical features. """Set categorical features.
...@@ -2784,21 +2912,25 @@ class Dataset: ...@@ -2784,21 +2912,25 @@ class Dataset:
if self.categorical_feature is None: if self.categorical_feature is None:
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
return self._free_handle() return self._free_handle()
elif categorical_feature == 'auto': elif categorical_feature == "auto":
return self return self
else: else:
if self.categorical_feature != 'auto': if self.categorical_feature != "auto":
_log_warning('categorical_feature in Dataset is overridden.\n' _log_warning(
f'New categorical_feature is {list(categorical_feature)}') "categorical_feature in Dataset is overridden.\n"
f"New categorical_feature is {list(categorical_feature)}"
)
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
return self._free_handle() return self._free_handle()
else: else:
raise LightGBMError("Cannot set categorical feature after freed raw data, " raise LightGBMError(
"set free_raw_data=False when construct Dataset to avoid this.") "Cannot set categorical feature after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this."
)
def _set_predictor( def _set_predictor(
self, self,
predictor: Optional[_InnerPredictor] predictor: Optional[_InnerPredictor],
) -> "Dataset": ) -> "Dataset":
"""Set predictor for continued training. """Set predictor for continued training.
...@@ -2808,7 +2940,9 @@ class Dataset: ...@@ -2808,7 +2940,9 @@ class Dataset:
if predictor is None and self._predictor is None: if predictor is None and self._predictor is None:
return self return self
elif isinstance(predictor, _InnerPredictor) and isinstance(self._predictor, _InnerPredictor): elif isinstance(predictor, _InnerPredictor) and isinstance(self._predictor, _InnerPredictor):
if (predictor == self._predictor) and (predictor.current_iteration() == self._predictor.current_iteration()): if (predictor == self._predictor) and (
predictor.current_iteration() == self._predictor.current_iteration()
):
return self return self
if self._handle is None: if self._handle is None:
self._predictor = predictor self._predictor = predictor
...@@ -2817,18 +2951,20 @@ class Dataset: ...@@ -2817,18 +2951,20 @@ class Dataset:
self._set_init_score_by_predictor( self._set_init_score_by_predictor(
predictor=self._predictor, predictor=self._predictor,
data=self.data, data=self.data,
used_indices=None used_indices=None,
) )
elif self.used_indices is not None and self.reference is not None and self.reference.data is not None: elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
self._predictor = predictor self._predictor = predictor
self._set_init_score_by_predictor( self._set_init_score_by_predictor(
predictor=self._predictor, predictor=self._predictor,
data=self.reference.data, data=self.reference.data,
used_indices=self.used_indices used_indices=self.used_indices,
) )
else: else:
raise LightGBMError("Cannot set predictor after freed raw data, " raise LightGBMError(
"set free_raw_data=False when construct Dataset to avoid this.") "Cannot set predictor after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this."
)
return self return self
def set_reference(self, reference: "Dataset") -> "Dataset": def set_reference(self, reference: "Dataset") -> "Dataset":
...@@ -2844,9 +2980,9 @@ class Dataset: ...@@ -2844,9 +2980,9 @@ class Dataset:
self : Dataset self : Dataset
Dataset with set reference. Dataset with set reference.
""" """
self.set_categorical_feature(reference.categorical_feature) \ self.set_categorical_feature(reference.categorical_feature).set_feature_name(
.set_feature_name(reference.feature_name) \ reference.feature_name
._set_predictor(reference._predictor) )._set_predictor(reference._predictor)
# we're done if self and reference share a common upstream reference # we're done if self and reference share a common upstream reference
if self.get_ref_chain().intersection(reference.get_ref_chain()): if self.get_ref_chain().intersection(reference.get_ref_chain()):
return self return self
...@@ -2854,8 +2990,10 @@ class Dataset: ...@@ -2854,8 +2990,10 @@ class Dataset:
self.reference = reference self.reference = reference
return self._free_handle() return self._free_handle()
else: else:
raise LightGBMError("Cannot set reference after freed raw data, " raise LightGBMError(
"set free_raw_data=False when construct Dataset to avoid this.") "Cannot set reference after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this."
)
def set_feature_name(self, feature_name: "_LGBM_FeatureNameConfiguration") -> "Dataset":
    """Set feature name.

    Parameters
    ----------
    feature_name : list of str, or 'auto'
        Feature names; the special value ``"auto"`` leaves the current names unchanged.

    Returns
    -------
    self : Dataset
        Dataset with set feature name.
    """
    if feature_name != "auto":
        self.feature_name = feature_name
    # Only push names down to the C++ Dataset once it has been constructed.
    if self._handle is not None and feature_name is not None and feature_name != "auto":
        if len(feature_name) != self.num_feature():
            raise ValueError(
                f"Length of feature_name({len(feature_name)}) and num_feature({self.num_feature()}) don't match"
            )
        c_feature_name = [_c_str(name) for name in feature_name]
        _safe_call(
            _LIB.LGBM_DatasetSetFeatureNames(
                self._handle,
                _c_array(ctypes.c_char_p, c_feature_name),
                ctypes.c_int(len(feature_name)),
            )
        )
    return self
def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset": def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
...@@ -2899,19 +3042,19 @@ class Dataset: ...@@ -2899,19 +3042,19 @@ class Dataset:
if self._handle is not None: if self._handle is not None:
if isinstance(label, pd_DataFrame): if isinstance(label, pd_DataFrame):
if len(label.columns) > 1: if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns') raise ValueError("DataFrame for label cannot have multiple columns")
label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32)) label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32))
elif _is_pyarrow_array(label): elif _is_pyarrow_array(label):
label_array = label label_array = label
else: else:
label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label') label_array = _list_to_1d_numpy(label, dtype=np.float32, name="label")
self.set_field('label', label_array) self.set_field("label", label_array)
self.label = self.get_field('label') # original values can be modified at cpp side self.label = self.get_field("label") # original values can be modified at cpp side
return self return self
def set_weight( def set_weight(
self, self,
weight: Optional[_LGBM_WeightType] weight: Optional[_LGBM_WeightType],
) -> "Dataset": ) -> "Dataset":
"""Set weight of each instance. """Set weight of each instance.
...@@ -2937,14 +3080,14 @@ class Dataset: ...@@ -2937,14 +3080,14 @@ class Dataset:
# Set field # Set field
if self._handle is not None and weight is not None: if self._handle is not None and weight is not None:
if not _is_pyarrow_array(weight): if not _is_pyarrow_array(weight):
weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight') weight = _list_to_1d_numpy(weight, dtype=np.float32, name="weight")
self.set_field('weight', weight) self.set_field("weight", weight)
self.weight = self.get_field('weight') # original values can be modified at cpp side self.weight = self.get_field("weight") # original values can be modified at cpp side
return self return self
def set_init_score(self, init_score: "Optional[_LGBM_InitScoreType]") -> "Dataset":
    """Set init score of Booster to start from.

    Parameters
    ----------
    init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray, pyarrow Table (for multi-class task) or None
        Init score for Booster.

    Returns
    -------
    self : Dataset
        Dataset with set init score.
    """
    self.init_score = init_score
    if self._handle is not None and init_score is not None:
        self.set_field("init_score", init_score)
        self.init_score = self.get_field("init_score")  # original values can be modified at cpp side
    return self
def set_group(self, group: "Optional[_LGBM_GroupType]") -> "Dataset":
    """Set group size of Dataset (used for ranking).

    Parameters
    ----------
    group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
        Group/query sizes for the Dataset.

    Returns
    -------
    self : Dataset
        Dataset with set group.
    """
    self.group = group
    if self._handle is not None and group is not None:
        if not _is_pyarrow_array(group):
            group = _list_to_1d_numpy(group, dtype=np.int32, name="group")
        self.set_field("group", group)
        # original values can be modified at cpp side
        constructed_group = self.get_field("group")
        if constructed_group is not None:
            # LightGBM stores group boundaries; convert back to per-group sizes.
            self.group = np.diff(constructed_group)
    return self
def set_position(self, position: "Optional[_LGBM_PositionType]") -> "Dataset":
    """Set position of Dataset (used for ranking).

    Parameters
    ----------
    position : list, numpy 1-D array, pandas Series or None
        Position of items used in unbiased learning-to-rank task.

    Returns
    -------
    self : Dataset
        Dataset with set position.
    """
    self.position = position
    if self._handle is not None and position is not None:
        position = _list_to_1d_numpy(position, dtype=np.int32, name="position")
        self.set_field("position", position)
    return self
def get_feature_name(self) -> List[str]: def get_feature_name(self) -> List[str]:
...@@ -3033,13 +3176,16 @@ class Dataset: ...@@ -3033,13 +3176,16 @@ class Dataset:
required_string_buffer_size = ctypes.c_size_t(0) required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)] string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc] ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc]
_safe_call(_LIB.LGBM_DatasetGetFeatureNames( _safe_call(
self._handle, _LIB.LGBM_DatasetGetFeatureNames(
ctypes.c_int(num_feature), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(num_feature),
ctypes.c_size_t(reserved_string_buffer_size), ctypes.byref(tmp_out_len),
ctypes.byref(required_string_buffer_size), ctypes.c_size_t(reserved_string_buffer_size),
ptr_string_buffers)) ctypes.byref(required_string_buffer_size),
ptr_string_buffers,
)
)
if num_feature != tmp_out_len.value: if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature") raise ValueError("Length of feature names doesn't equal with num_feature")
actual_string_buffer_size = required_string_buffer_size.value actual_string_buffer_size = required_string_buffer_size.value
...@@ -3047,14 +3193,17 @@ class Dataset: ...@@ -3047,14 +3193,17 @@ class Dataset:
if reserved_string_buffer_size < actual_string_buffer_size: if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)] string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc] ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc]
_safe_call(_LIB.LGBM_DatasetGetFeatureNames( _safe_call(
self._handle, _LIB.LGBM_DatasetGetFeatureNames(
ctypes.c_int(num_feature), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(num_feature),
ctypes.c_size_t(actual_string_buffer_size), ctypes.byref(tmp_out_len),
ctypes.byref(required_string_buffer_size), ctypes.c_size_t(actual_string_buffer_size),
ptr_string_buffers)) ctypes.byref(required_string_buffer_size),
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)] ptr_string_buffers,
)
)
return [string_buffers[i].value.decode("utf-8") for i in range(num_feature)]
def get_label(self) -> "Optional[_LGBM_LabelType]":
    """Get the label of the Dataset.

    Returns
    -------
    label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
        The label information from the Dataset.
        For a constructed ``Dataset``, this will only return a numpy array.
    """
    # Lazily fetch from the constructed Dataset only when not cached locally.
    if self.label is None:
        self.label = self.get_field("label")
    return self.label
def get_weight(self) -> "Optional[_LGBM_WeightType]":
    """Get the weight of the Dataset.

    Returns
    -------
    weight : list, numpy 1-D array, pandas Series or None
        Weight for each data point from the Dataset.
        For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
    """
    # Lazily fetch from the constructed Dataset only when not cached locally.
    if self.weight is None:
        self.weight = self.get_field("weight")
    return self.weight
def get_init_score(self) -> "Optional[_LGBM_InitScoreType]":
    """Get the initial score of the Dataset.

    Returns
    -------
    init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task) or None
        Init score of Booster.
        For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
    """
    # Lazily fetch from the constructed Dataset only when not cached locally.
    if self.init_score is None:
        self.init_score = self.get_field("init_score")
    return self.init_score
def get_data(self) -> Optional[_LGBM_TrainDataType]: def get_data(self) -> Optional[_LGBM_TrainDataType]:
...@@ -3119,12 +3268,15 @@ class Dataset: ...@@ -3119,12 +3268,15 @@ class Dataset:
elif _is_list_of_sequences(self.data) and len(self.data) > 0: elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices))) self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
else: else:
_log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n" _log_warning(
"Returning original raw data") f"Cannot subset {type(self.data).__name__} type of raw data.\n" "Returning original raw data"
)
self._need_slice = False self._need_slice = False
if self.data is None: if self.data is None:
raise LightGBMError("Cannot call `get_data` after freed raw data, " raise LightGBMError(
"set free_raw_data=False when construct Dataset to avoid this.") "Cannot call `get_data` after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this."
)
return self.data return self.data
def get_group(self) -> Optional[_LGBM_GroupType]: def get_group(self) -> Optional[_LGBM_GroupType]:
...@@ -3141,7 +3293,7 @@ class Dataset: ...@@ -3141,7 +3293,7 @@ class Dataset:
For a constructed ``Dataset``, this will only return ``None`` or a numpy array. For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
""" """
if self.group is None: if self.group is None:
self.group = self.get_field('group') self.group = self.get_field("group")
if self.group is not None: if self.group is not None:
# group data from LightGBM is boundaries data, need to convert to group size # group data from LightGBM is boundaries data, need to convert to group size
self.group = np.diff(self.group) self.group = np.diff(self.group)
...@@ -3157,7 +3309,7 @@ class Dataset: ...@@ -3157,7 +3309,7 @@ class Dataset:
For a constructed ``Dataset``, this will only return ``None`` or a numpy array. For a constructed ``Dataset``, this will only return ``None`` or a numpy array.
""" """
if self.position is None: if self.position is None:
self.position = self.get_field('position') self.position = self.get_field("position")
return self.position return self.position
def num_data(self) -> int:
    """Get the number of rows in the Dataset.

    Returns
    -------
    number_of_rows : int
        The number of rows in the Dataset.

    Raises
    ------
    LightGBMError
        If the Dataset has not been constructed yet.
    """
    if self._handle is not None:
        ret = ctypes.c_int(0)
        _safe_call(
            _LIB.LGBM_DatasetGetNumData(
                self._handle,
                ctypes.byref(ret),
            )
        )
        return ret.value
    else:
        raise LightGBMError("Cannot get num_data before construct dataset")
...@@ -3186,8 +3342,12 @@ class Dataset: ...@@ -3186,8 +3342,12 @@ class Dataset:
""" """
if self._handle is not None: if self._handle is not None:
ret = ctypes.c_int(0) ret = ctypes.c_int(0)
_safe_call(_LIB.LGBM_DatasetGetNumFeature(self._handle, _safe_call(
ctypes.byref(ret))) _LIB.LGBM_DatasetGetNumFeature(
self._handle,
ctypes.byref(ret),
)
)
return ret.value return ret.value
else: else:
raise LightGBMError("Cannot get num_feature before construct dataset") raise LightGBMError("Cannot get num_feature before construct dataset")
...@@ -3213,9 +3373,13 @@ class Dataset: ...@@ -3213,9 +3373,13 @@ class Dataset:
else: else:
feature_index = feature feature_index = feature
ret = ctypes.c_int(0) ret = ctypes.c_int(0)
_safe_call(_LIB.LGBM_DatasetGetFeatureNumBin(self._handle, _safe_call(
ctypes.c_int(feature_index), _LIB.LGBM_DatasetGetFeatureNumBin(
ctypes.byref(ret))) self._handle,
ctypes.c_int(feature_index),
ctypes.byref(ret),
)
)
return ret.value return ret.value
else: else:
raise LightGBMError("Cannot get feature_num_bin before construct dataset") raise LightGBMError("Cannot get feature_num_bin before construct dataset")
...@@ -3266,8 +3430,13 @@ class Dataset: ...@@ -3266,8 +3430,13 @@ class Dataset:
Dataset with the new features added. Dataset with the new features added.
""" """
if self._handle is None or other._handle is None: if self._handle is None or other._handle is None:
raise ValueError('Both source and target Datasets must be constructed before adding features') raise ValueError("Both source and target Datasets must be constructed before adding features")
_safe_call(_LIB.LGBM_DatasetAddFeaturesFrom(self._handle, other._handle)) _safe_call(
_LIB.LGBM_DatasetAddFeaturesFrom(
self._handle,
other._handle,
)
)
was_none = self.data is None was_none = self.data is None
old_self_data_type = type(self.data).__name__ old_self_data_type = type(self.data).__name__
if other.data is None: if other.data is None:
...@@ -3296,21 +3465,19 @@ class Dataset: ...@@ -3296,21 +3465,19 @@ class Dataset:
self.data = None self.data = None
elif isinstance(self.data, pd_DataFrame): elif isinstance(self.data, pd_DataFrame):
if not PANDAS_INSTALLED: if not PANDAS_INSTALLED:
raise LightGBMError("Cannot add features to DataFrame type of raw data " raise LightGBMError(
"without pandas installed. " "Cannot add features to DataFrame type of raw data "
"Install pandas and restart your session.") "without pandas installed. "
"Install pandas and restart your session."
)
if isinstance(other.data, np.ndarray): if isinstance(other.data, np.ndarray):
self.data = concat((self.data, pd_DataFrame(other.data)), self.data = concat((self.data, pd_DataFrame(other.data)), axis=1, ignore_index=True)
axis=1, ignore_index=True)
elif isinstance(other.data, scipy.sparse.spmatrix): elif isinstance(other.data, scipy.sparse.spmatrix):
self.data = concat((self.data, pd_DataFrame(other.data.toarray())), self.data = concat((self.data, pd_DataFrame(other.data.toarray())), axis=1, ignore_index=True)
axis=1, ignore_index=True)
elif isinstance(other.data, pd_DataFrame): elif isinstance(other.data, pd_DataFrame):
self.data = concat((self.data, other.data), self.data = concat((self.data, other.data), axis=1, ignore_index=True)
axis=1, ignore_index=True)
elif isinstance(other.data, dt_DataTable): elif isinstance(other.data, dt_DataTable):
self.data = concat((self.data, pd_DataFrame(other.data.to_numpy())), self.data = concat((self.data, pd_DataFrame(other.data.to_numpy())), axis=1, ignore_index=True)
axis=1, ignore_index=True)
else: else:
self.data = None self.data = None
elif isinstance(self.data, dt_DataTable): elif isinstance(self.data, dt_DataTable):
...@@ -3327,14 +3494,19 @@ class Dataset: ...@@ -3327,14 +3494,19 @@ class Dataset:
else: else:
self.data = None self.data = None
if self.data is None: if self.data is None:
err_msg = (f"Cannot add features from {type(other.data).__name__} type of raw data to " err_msg = (
f"{old_self_data_type} type of raw data.\n") f"Cannot add features from {type(other.data).__name__} type of raw data to "
err_msg += ("Set free_raw_data=False when construct Dataset to avoid this" f"{old_self_data_type} type of raw data.\n"
if was_none else "Freeing raw data") )
err_msg += (
"Set free_raw_data=False when construct Dataset to avoid this" if was_none else "Freeing raw data"
)
_log_warning(err_msg) _log_warning(err_msg)
self.feature_name = self.get_feature_name() self.feature_name = self.get_feature_name()
_log_warning("Reseting categorical features.\n" _log_warning(
"You can set new categorical features via ``set_categorical_feature`` method") "Reseting categorical features.\n"
"You can set new categorical features via ``set_categorical_feature`` method"
)
self.categorical_feature = "auto" self.categorical_feature = "auto"
self.pandas_categorical = None self.pandas_categorical = None
return self return self
...@@ -3354,25 +3526,28 @@ class Dataset: ...@@ -3354,25 +3526,28 @@ class Dataset:
self : Dataset self : Dataset
Returns self. Returns self.
""" """
_safe_call(_LIB.LGBM_DatasetDumpText( _safe_call(
self.construct()._handle, _LIB.LGBM_DatasetDumpText(
_c_str(str(filename)))) self.construct()._handle,
_c_str(str(filename)),
)
)
return self return self
# Type alias for user-supplied objective functions: a callable taking
# (predictions as np.ndarray, train Dataset) and returning a (grad, hess) pair.
_LGBM_CustomObjectiveFunction = Callable[
    [np.ndarray, Dataset],
    Tuple[np.ndarray, np.ndarray],
]
# Type alias for user-supplied evaluation functions: a callable taking
# (predictions as np.ndarray, Dataset) and returning either a single
# evaluation result or a list of evaluation results.
_LGBM_CustomEvalFunction = Union[
    Callable[
        [np.ndarray, Dataset],
        _LGBM_EvalFunctionResultType,
    ],
    Callable[
        [np.ndarray, Dataset],
        List[_LGBM_EvalFunctionResultType],
    ],
]
...@@ -3384,7 +3559,7 @@ class Booster: ...@@ -3384,7 +3559,7 @@ class Booster:
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
train_set: Optional[Dataset] = None, train_set: Optional[Dataset] = None,
model_file: Optional[Union[str, Path]] = None, model_file: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None model_str: Optional[str] = None,
): ):
"""Initialize the Booster. """Initialize the Booster.
...@@ -3410,11 +3585,11 @@ class Booster: ...@@ -3410,11 +3585,11 @@ class Booster:
if train_set is not None: if train_set is not None:
# Training task # Training task
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError(f'Training data should be Dataset instance, met {type(train_set).__name__}') raise TypeError(f"Training data should be Dataset instance, met {type(train_set).__name__}")
params = _choose_param_value( params = _choose_param_value(
main_param_name="machines", main_param_name="machines",
params=params, params=params,
default_value=None default_value=None,
) )
# if "machines" is given, assume user wants to do distributed learning, and set up network # if "machines" is given, assume user wants to do distributed learning, and set up network
if params["machines"] is None: if params["machines"] is None:
...@@ -3422,38 +3597,41 @@ class Booster: ...@@ -3422,38 +3597,41 @@ class Booster:
else: else:
machines = params["machines"] machines = params["machines"]
if isinstance(machines, str): if isinstance(machines, str):
num_machines_from_machine_list = len(machines.split(',')) num_machines_from_machine_list = len(machines.split(","))
elif isinstance(machines, (list, set)): elif isinstance(machines, (list, set)):
num_machines_from_machine_list = len(machines) num_machines_from_machine_list = len(machines)
machines = ','.join(machines) machines = ",".join(machines)
else: else:
raise ValueError("Invalid machines in params.") raise ValueError("Invalid machines in params.")
params = _choose_param_value( params = _choose_param_value(
main_param_name="num_machines", main_param_name="num_machines",
params=params, params=params,
default_value=num_machines_from_machine_list default_value=num_machines_from_machine_list,
) )
params = _choose_param_value( params = _choose_param_value(
main_param_name="local_listen_port", main_param_name="local_listen_port",
params=params, params=params,
default_value=12400 default_value=12400,
) )
self.set_network( self.set_network(
machines=machines, machines=machines,
local_listen_port=params["local_listen_port"], local_listen_port=params["local_listen_port"],
listen_time_out=params.get("time_out", 120), listen_time_out=params.get("time_out", 120),
num_machines=params["num_machines"] num_machines=params["num_machines"],
) )
# construct booster object # construct booster object
train_set.construct() train_set.construct()
# copy the parameters from train_set # copy the parameters from train_set
params.update(train_set.get_params()) params.update(train_set.get_params())
params_str = _param_dict_to_str(params) params_str = _param_dict_to_str(params)
_safe_call(_LIB.LGBM_BoosterCreate( _safe_call(
train_set._handle, _LIB.LGBM_BoosterCreate(
_c_str(params_str), train_set._handle,
ctypes.byref(self._handle))) _c_str(params_str),
ctypes.byref(self._handle),
)
)
# save reference to data # save reference to data
self.train_set = train_set self.train_set = train_set
self.valid_sets: List[Dataset] = [] self.valid_sets: List[Dataset] = []
...@@ -3461,13 +3639,19 @@ class Booster: ...@@ -3461,13 +3639,19 @@ class Booster:
self.__num_dataset = 1 self.__num_dataset = 1
self.__init_predictor = train_set._predictor self.__init_predictor = train_set._predictor
if self.__init_predictor is not None: if self.__init_predictor is not None:
_safe_call(_LIB.LGBM_BoosterMerge( _safe_call(
self._handle, _LIB.LGBM_BoosterMerge(
self.__init_predictor._handle)) self._handle,
self.__init_predictor._handle,
)
)
out_num_class = ctypes.c_int(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(
self._handle, _LIB.LGBM_BoosterGetNumClasses(
ctypes.byref(out_num_class))) self._handle,
ctypes.byref(out_num_class),
)
)
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
# buffer for inner predict # buffer for inner predict
self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None] self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
...@@ -3478,24 +3662,31 @@ class Booster: ...@@ -3478,24 +3662,31 @@ class Booster:
elif model_file is not None: elif model_file is not None:
# Prediction task # Prediction task
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile( _safe_call(
_c_str(str(model_file)), _LIB.LGBM_BoosterCreateFromModelfile(
ctypes.byref(out_num_iterations), _c_str(str(model_file)),
ctypes.byref(self._handle))) ctypes.byref(out_num_iterations),
ctypes.byref(self._handle),
)
)
out_num_class = ctypes.c_int(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(
self._handle, _LIB.LGBM_BoosterGetNumClasses(
ctypes.byref(out_num_class))) self._handle,
ctypes.byref(out_num_class),
)
)
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file) self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
if params: if params:
_log_warning('Ignoring params argument, using parameters from model file.') _log_warning("Ignoring params argument, using parameters from model file.")
params = self._get_loaded_param() params = self._get_loaded_param()
elif model_str is not None: elif model_str is not None:
self.model_from_string(model_str) self.model_from_string(model_str)
else: else:
raise TypeError('Need at least one training dataset or model file or model string ' raise TypeError(
'to create Booster instance') "Need at least one training dataset or model file or model string " "to create Booster instance"
)
self.params = params self.params = params
def __del__(self) -> None: def __del__(self) -> None:
...@@ -3519,23 +3710,26 @@ class Booster: ...@@ -3519,23 +3710,26 @@ class Booster:
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
this = self.__dict__.copy() this = self.__dict__.copy()
handle = this['_handle'] handle = this["_handle"]
this.pop('train_set', None) this.pop("train_set", None)
this.pop('valid_sets', None) this.pop("valid_sets", None)
if handle is not None: if handle is not None:
this["_handle"] = self.model_to_string(num_iteration=-1) this["_handle"] = self.model_to_string(num_iteration=-1)
return this return this
def __setstate__(self, state: Dict[str, Any]) -> None: def __setstate__(self, state: Dict[str, Any]) -> None:
model_str = state.get('_handle', state.get('handle', None)) model_str = state.get("_handle", state.get("handle", None))
if model_str is not None: if model_str is not None:
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterLoadModelFromString( _safe_call(
_c_str(model_str), _LIB.LGBM_BoosterLoadModelFromString(
ctypes.byref(out_num_iterations), _c_str(model_str),
ctypes.byref(handle))) ctypes.byref(out_num_iterations),
state['_handle'] = handle ctypes.byref(handle),
)
)
state["_handle"] = handle
self.__dict__.update(state) self.__dict__.update(state)
def _get_loaded_param(self) -> Dict[str, Any]: def _get_loaded_param(self) -> Dict[str, Any]:
...@@ -3543,22 +3737,28 @@ class Booster: ...@@ -3543,22 +3737,28 @@ class Booster:
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer)) ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_BoosterGetLoadedParam( _safe_call(
self._handle, _LIB.LGBM_BoosterGetLoadedParam(
ctypes.c_int64(buffer_len), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int64(buffer_len),
ptr_string_buffer)) ctypes.byref(tmp_out_len),
ptr_string_buffer,
)
)
actual_len = tmp_out_len.value actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer # if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len: if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len) string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer)) ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_BoosterGetLoadedParam( _safe_call(
self._handle, _LIB.LGBM_BoosterGetLoadedParam(
ctypes.c_int64(actual_len), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int64(actual_len),
ptr_string_buffer)) ctypes.byref(tmp_out_len),
return json.loads(string_buffer.value.decode('utf-8')) ptr_string_buffer,
)
)
return json.loads(string_buffer.value.decode("utf-8"))
def free_dataset(self) -> "Booster": def free_dataset(self) -> "Booster":
"""Free Booster's Datasets. """Free Booster's Datasets.
...@@ -3568,8 +3768,8 @@ class Booster: ...@@ -3568,8 +3768,8 @@ class Booster:
self : Booster self : Booster
Booster without Datasets. Booster without Datasets.
""" """
self.__dict__.pop('train_set', None) self.__dict__.pop("train_set", None)
self.__dict__.pop('valid_sets', None) self.__dict__.pop("valid_sets", None)
self.__num_dataset = 0 self.__num_dataset = 0
return self return self
...@@ -3583,7 +3783,7 @@ class Booster: ...@@ -3583,7 +3783,7 @@ class Booster:
machines: Union[List[str], Set[str], str], machines: Union[List[str], Set[str], str],
local_listen_port: int = 12400, local_listen_port: int = 12400,
listen_time_out: int = 120, listen_time_out: int = 120,
num_machines: int = 1 num_machines: int = 1,
) -> "Booster": ) -> "Booster":
"""Set the network configuration. """Set the network configuration.
...@@ -3604,11 +3804,15 @@ class Booster: ...@@ -3604,11 +3804,15 @@ class Booster:
Booster with set network. Booster with set network.
""" """
if isinstance(machines, (list, set)): if isinstance(machines, (list, set)):
machines = ','.join(machines) machines = ",".join(machines)
_safe_call(_LIB.LGBM_NetworkInit(_c_str(machines), _safe_call(
ctypes.c_int(local_listen_port), _LIB.LGBM_NetworkInit(
ctypes.c_int(listen_time_out), _c_str(machines),
ctypes.c_int(num_machines))) ctypes.c_int(local_listen_port),
ctypes.c_int(listen_time_out),
ctypes.c_int(num_machines),
)
)
self._network = True self._network = True
return self return self
...@@ -3653,85 +3857,86 @@ class Booster: ...@@ -3653,85 +3857,86 @@ class Booster:
Returns a pandas DataFrame of the parsed model. Returns a pandas DataFrame of the parsed model.
""" """
if not PANDAS_INSTALLED: if not PANDAS_INSTALLED:
raise LightGBMError('This method cannot be run without pandas installed. ' raise LightGBMError(
'You must install pandas and restart your session to use this method.') "This method cannot be run without pandas installed. "
"You must install pandas and restart your session to use this method."
)
if self.num_trees() == 0: if self.num_trees() == 0:
raise LightGBMError('There are no trees in this Booster and thus nothing to parse') raise LightGBMError("There are no trees in this Booster and thus nothing to parse")
def _is_split_node(tree: Dict[str, Any]) -> bool: def _is_split_node(tree: Dict[str, Any]) -> bool:
return 'split_index' in tree.keys() return "split_index" in tree.keys()
def create_node_record( def create_node_record(
tree: Dict[str, Any], tree: Dict[str, Any],
node_depth: int = 1, node_depth: int = 1,
tree_index: Optional[int] = None, tree_index: Optional[int] = None,
feature_names: Optional[List[str]] = None, feature_names: Optional[List[str]] = None,
parent_node: Optional[str] = None parent_node: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
def _get_node_index( def _get_node_index(
tree: Dict[str, Any], tree: Dict[str, Any],
tree_index: Optional[int] tree_index: Optional[int],
) -> str: ) -> str:
tree_num = f'{tree_index}-' if tree_index is not None else '' tree_num = f"{tree_index}-" if tree_index is not None else ""
is_split = _is_split_node(tree) is_split = _is_split_node(tree)
node_type = 'S' if is_split else 'L' node_type = "S" if is_split else "L"
# if a single node tree it won't have `leaf_index` so return 0 # if a single node tree it won't have `leaf_index` so return 0
node_num = tree.get('split_index' if is_split else 'leaf_index', 0) node_num = tree.get("split_index" if is_split else "leaf_index", 0)
return f"{tree_num}{node_type}{node_num}" return f"{tree_num}{node_type}{node_num}"
def _get_split_feature( def _get_split_feature(
tree: Dict[str, Any], tree: Dict[str, Any],
feature_names: Optional[List[str]] feature_names: Optional[List[str]],
) -> Optional[str]: ) -> Optional[str]:
if _is_split_node(tree): if _is_split_node(tree):
if feature_names is not None: if feature_names is not None:
feature_name = feature_names[tree['split_feature']] feature_name = feature_names[tree["split_feature"]]
else: else:
feature_name = tree['split_feature'] feature_name = tree["split_feature"]
else: else:
feature_name = None feature_name = None
return feature_name return feature_name
def _is_single_node_tree(tree: Dict[str, Any]) -> bool: def _is_single_node_tree(tree: Dict[str, Any]) -> bool:
return set(tree.keys()) == {'leaf_value'} return set(tree.keys()) == {"leaf_value"}
# Create the node record, and populate universal data members # Create the node record, and populate universal data members
node: Dict[str, Union[int, str, None]] = OrderedDict() node: Dict[str, Union[int, str, None]] = OrderedDict()
node['tree_index'] = tree_index node["tree_index"] = tree_index
node['node_depth'] = node_depth node["node_depth"] = node_depth
node['node_index'] = _get_node_index(tree, tree_index) node["node_index"] = _get_node_index(tree, tree_index)
node['left_child'] = None node["left_child"] = None
node['right_child'] = None node["right_child"] = None
node['parent_index'] = parent_node node["parent_index"] = parent_node
node['split_feature'] = _get_split_feature(tree, feature_names) node["split_feature"] = _get_split_feature(tree, feature_names)
node['split_gain'] = None node["split_gain"] = None
node['threshold'] = None node["threshold"] = None
node['decision_type'] = None node["decision_type"] = None
node['missing_direction'] = None node["missing_direction"] = None
node['missing_type'] = None node["missing_type"] = None
node['value'] = None node["value"] = None
node['weight'] = None node["weight"] = None
node['count'] = None node["count"] = None
# Update values to reflect node type (leaf or split) # Update values to reflect node type (leaf or split)
if _is_split_node(tree): if _is_split_node(tree):
node['left_child'] = _get_node_index(tree['left_child'], tree_index) node["left_child"] = _get_node_index(tree["left_child"], tree_index)
node['right_child'] = _get_node_index(tree['right_child'], tree_index) node["right_child"] = _get_node_index(tree["right_child"], tree_index)
node['split_gain'] = tree['split_gain'] node["split_gain"] = tree["split_gain"]
node['threshold'] = tree['threshold'] node["threshold"] = tree["threshold"]
node['decision_type'] = tree['decision_type'] node["decision_type"] = tree["decision_type"]
node['missing_direction'] = 'left' if tree['default_left'] else 'right' node["missing_direction"] = "left" if tree["default_left"] else "right"
node['missing_type'] = tree['missing_type'] node["missing_type"] = tree["missing_type"]
node['value'] = tree['internal_value'] node["value"] = tree["internal_value"]
node['weight'] = tree['internal_weight'] node["weight"] = tree["internal_weight"]
node['count'] = tree['internal_count'] node["count"] = tree["internal_count"]
else: else:
node['value'] = tree['leaf_value'] node["value"] = tree["leaf_value"]
if not _is_single_node_tree(tree): if not _is_single_node_tree(tree):
node['weight'] = tree['leaf_weight'] node["weight"] = tree["leaf_weight"]
node['count'] = tree['leaf_count'] node["count"] = tree["leaf_count"]
return node return node
...@@ -3740,27 +3945,28 @@ class Booster: ...@@ -3740,27 +3945,28 @@ class Booster:
node_depth: int = 1, node_depth: int = 1,
tree_index: Optional[int] = None, tree_index: Optional[int] = None,
feature_names: Optional[List[str]] = None, feature_names: Optional[List[str]] = None,
parent_node: Optional[str] = None parent_node: Optional[str] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
node = create_node_record(
node = create_node_record(tree=tree, tree=tree,
node_depth=node_depth, node_depth=node_depth,
tree_index=tree_index, tree_index=tree_index,
feature_names=feature_names, feature_names=feature_names,
parent_node=parent_node) parent_node=parent_node,
)
res = [node] res = [node]
if _is_split_node(tree): if _is_split_node(tree):
# traverse the next level of the tree # traverse the next level of the tree
children = ['left_child', 'right_child'] children = ["left_child", "right_child"]
for child in children: for child in children:
subtree_list = tree_dict_to_node_list( subtree_list = tree_dict_to_node_list(
tree=tree[child], tree=tree[child],
node_depth=node_depth + 1, node_depth=node_depth + 1,
tree_index=tree_index, tree_index=tree_index,
feature_names=feature_names, feature_names=feature_names,
parent_node=node['node_index'] parent_node=node["node_index"],
) )
# In tree format, "subtree_list" is a list of node records (dicts), # In tree format, "subtree_list" is a list of node records (dicts),
# and we add node to the list. # and we add node to the list.
...@@ -3768,12 +3974,14 @@ class Booster: ...@@ -3768,12 +3974,14 @@ class Booster:
return res return res
model_dict = self.dump_model() model_dict = self.dump_model()
feature_names = model_dict['feature_names'] feature_names = model_dict["feature_names"]
model_list = [] model_list = []
for tree in model_dict['tree_info']: for tree in model_dict["tree_info"]:
model_list.extend(tree_dict_to_node_list(tree=tree['tree_structure'], model_list.extend(
tree_index=tree['tree_index'], tree_dict_to_node_list(
feature_names=feature_names)) tree=tree["tree_structure"], tree_index=tree["tree_index"], feature_names=feature_names
)
)
return pd_DataFrame(model_list, columns=model_list[0].keys()) return pd_DataFrame(model_list, columns=model_list[0].keys())
...@@ -3809,13 +4017,15 @@ class Booster: ...@@ -3809,13 +4017,15 @@ class Booster:
Booster with set validation data. Booster with set validation data.
""" """
if not isinstance(data, Dataset): if not isinstance(data, Dataset):
raise TypeError(f'Validation data should be Dataset instance, met {type(data).__name__}') raise TypeError(f"Validation data should be Dataset instance, met {type(data).__name__}")
if data._predictor is not self.__init_predictor: if data._predictor is not self.__init_predictor:
raise LightGBMError("Add validation data failed, " raise LightGBMError("Add validation data failed, " "you should use same predictor for these data")
"you should use same predictor for these data") _safe_call(
_safe_call(_LIB.LGBM_BoosterAddValidData( _LIB.LGBM_BoosterAddValidData(
self._handle, self._handle,
data.construct()._handle)) data.construct()._handle,
)
)
self.valid_sets.append(data) self.valid_sets.append(data)
self.name_valid_sets.append(name) self.name_valid_sets.append(name)
self.__num_dataset += 1 self.__num_dataset += 1
...@@ -3838,16 +4048,19 @@ class Booster: ...@@ -3838,16 +4048,19 @@ class Booster:
""" """
params_str = _param_dict_to_str(params) params_str = _param_dict_to_str(params)
if params_str: if params_str:
_safe_call(_LIB.LGBM_BoosterResetParameter( _safe_call(
self._handle, _LIB.LGBM_BoosterResetParameter(
_c_str(params_str))) self._handle,
_c_str(params_str),
)
)
self.params.update(params) self.params.update(params)
return self return self
def update( def update(
self, self,
train_set: Optional[Dataset] = None, train_set: Optional[Dataset] = None,
fobj: Optional[_LGBM_CustomObjectiveFunction] = None fobj: Optional[_LGBM_CustomObjectiveFunction] = None,
) -> bool: ) -> bool:
"""Update Booster for one iteration. """Update Booster for one iteration.
...@@ -3890,23 +4103,28 @@ class Booster: ...@@ -3890,23 +4103,28 @@ class Booster:
is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
if train_set is not None and not is_the_same_train_set: if train_set is not None and not is_the_same_train_set:
if not isinstance(train_set, Dataset): if not isinstance(train_set, Dataset):
raise TypeError(f'Training data should be Dataset instance, met {type(train_set).__name__}') raise TypeError(f"Training data should be Dataset instance, met {type(train_set).__name__}")
if train_set._predictor is not self.__init_predictor: if train_set._predictor is not self.__init_predictor:
raise LightGBMError("Replace training data failed, " raise LightGBMError("Replace training data failed, " "you should use same predictor for these data")
"you should use same predictor for these data")
self.train_set = train_set self.train_set = train_set
_safe_call(_LIB.LGBM_BoosterResetTrainingData( _safe_call(
self._handle, _LIB.LGBM_BoosterResetTrainingData(
self.train_set.construct()._handle)) self._handle,
self.train_set.construct()._handle,
)
)
self.__inner_predict_buffer[0] = None self.__inner_predict_buffer[0] = None
self.train_set_version = self.train_set.version self.train_set_version = self.train_set.version
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
if fobj is None: if fobj is None:
if self.__set_objective_to_none: if self.__set_objective_to_none:
raise LightGBMError('Cannot update due to null objective function.') raise LightGBMError("Cannot update due to null objective function.")
_safe_call(_LIB.LGBM_BoosterUpdateOneIter( _safe_call(
self._handle, _LIB.LGBM_BoosterUpdateOneIter(
ctypes.byref(is_finished))) self._handle,
ctypes.byref(is_finished),
)
)
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return is_finished.value == 1 return is_finished.value == 1
else: else:
...@@ -3918,7 +4136,7 @@ class Booster: ...@@ -3918,7 +4136,7 @@ class Booster:
def __boost( def __boost(
self, self,
grad: np.ndarray, grad: np.ndarray,
hess: np.ndarray hess: np.ndarray,
) -> bool: ) -> bool:
"""Boost Booster for one iteration with customized gradient statistics. """Boost Booster for one iteration with customized gradient statistics.
...@@ -3944,10 +4162,10 @@ class Booster: ...@@ -3944,10 +4162,10 @@ class Booster:
Whether the boost was successfully finished. Whether the boost was successfully finished.
""" """
if self.__num_class > 1: if self.__num_class > 1:
grad = grad.ravel(order='F') grad = grad.ravel(order="F")
hess = hess.ravel(order='F') hess = hess.ravel(order="F")
grad = _list_to_1d_numpy(grad, dtype=np.float32, name='gradient') grad = _list_to_1d_numpy(grad, dtype=np.float32, name="gradient")
hess = _list_to_1d_numpy(hess, dtype=np.float32, name='hessian') hess = _list_to_1d_numpy(hess, dtype=np.float32, name="hessian")
assert grad.flags.c_contiguous assert grad.flags.c_contiguous
assert hess.flags.c_contiguous assert hess.flags.c_contiguous
if len(grad) != len(hess): if len(grad) != len(hess):
...@@ -3960,11 +4178,14 @@ class Booster: ...@@ -3960,11 +4178,14 @@ class Booster:
f"number of models per one iteration ({self.__num_class})" f"number of models per one iteration ({self.__num_class})"
) )
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom( _safe_call(
self._handle, _LIB.LGBM_BoosterUpdateOneIterCustom(
grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), self._handle,
hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), grad.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
ctypes.byref(is_finished))) hess.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
ctypes.byref(is_finished),
)
)
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return is_finished.value == 1 return is_finished.value == 1
...@@ -3976,8 +4197,7 @@ class Booster: ...@@ -3976,8 +4197,7 @@ class Booster:
self : Booster self : Booster
Booster with rolled back one iteration. Booster with rolled back one iteration.
""" """
_safe_call(_LIB.LGBM_BoosterRollbackOneIter( _safe_call(_LIB.LGBM_BoosterRollbackOneIter(self._handle))
self._handle))
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)] self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return self return self
...@@ -3990,9 +4210,12 @@ class Booster: ...@@ -3990,9 +4210,12 @@ class Booster:
The index of the current iteration. The index of the current iteration.
""" """
out_cur_iter = ctypes.c_int(0) out_cur_iter = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration( _safe_call(
self._handle, _LIB.LGBM_BoosterGetCurrentIteration(
ctypes.byref(out_cur_iter))) self._handle,
ctypes.byref(out_cur_iter),
)
)
return out_cur_iter.value return out_cur_iter.value
def num_model_per_iteration(self) -> int: def num_model_per_iteration(self) -> int:
...@@ -4004,9 +4227,12 @@ class Booster: ...@@ -4004,9 +4227,12 @@ class Booster:
The number of models per iteration. The number of models per iteration.
""" """
model_per_iter = ctypes.c_int(0) model_per_iter = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterNumModelPerIteration( _safe_call(
self._handle, _LIB.LGBM_BoosterNumModelPerIteration(
ctypes.byref(model_per_iter))) self._handle,
ctypes.byref(model_per_iter),
)
)
return model_per_iter.value return model_per_iter.value
def num_trees(self) -> int: def num_trees(self) -> int:
...@@ -4018,9 +4244,12 @@ class Booster: ...@@ -4018,9 +4244,12 @@ class Booster:
The number of weak sub-models. The number of weak sub-models.
""" """
num_trees = ctypes.c_int(0) num_trees = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterNumberOfTotalModel( _safe_call(
self._handle, _LIB.LGBM_BoosterNumberOfTotalModel(
ctypes.byref(num_trees))) self._handle,
ctypes.byref(num_trees),
)
)
return num_trees.value return num_trees.value
def upper_bound(self) -> float: def upper_bound(self) -> float:
...@@ -4032,9 +4261,12 @@ class Booster: ...@@ -4032,9 +4261,12 @@ class Booster:
Upper bound value of the model. Upper bound value of the model.
""" """
ret = ctypes.c_double(0) ret = ctypes.c_double(0)
_safe_call(_LIB.LGBM_BoosterGetUpperBoundValue( _safe_call(
self._handle, _LIB.LGBM_BoosterGetUpperBoundValue(
ctypes.byref(ret))) self._handle,
ctypes.byref(ret),
)
)
return ret.value return ret.value
def lower_bound(self) -> float: def lower_bound(self) -> float:
...@@ -4046,16 +4278,19 @@ class Booster: ...@@ -4046,16 +4278,19 @@ class Booster:
Lower bound value of the model. Lower bound value of the model.
""" """
ret = ctypes.c_double(0) ret = ctypes.c_double(0)
_safe_call(_LIB.LGBM_BoosterGetLowerBoundValue( _safe_call(
self._handle, _LIB.LGBM_BoosterGetLowerBoundValue(
ctypes.byref(ret))) self._handle,
ctypes.byref(ret),
)
)
return ret.value return ret.value
def eval( def eval(
self, self,
data: Dataset, data: Dataset,
name: str, name: str,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None,
) -> List[_LGBM_BoosterEvalMethodResultType]: ) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for data. """Evaluate for data.
...@@ -4108,7 +4343,7 @@ class Booster: ...@@ -4108,7 +4343,7 @@ class Booster:
def eval_train( def eval_train(
self, self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None,
) -> List[_LGBM_BoosterEvalMethodResultType]: ) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for training data. """Evaluate for training data.
...@@ -4142,7 +4377,7 @@ class Booster: ...@@ -4142,7 +4377,7 @@ class Booster:
def eval_valid( def eval_valid(
self, self,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] = None,
) -> List[_LGBM_BoosterEvalMethodResultType]: ) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate for validation data. """Evaluate for validation data.
...@@ -4172,15 +4407,18 @@ class Booster: ...@@ -4172,15 +4407,18 @@ class Booster:
result : list result : list
List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples. List with (validation_dataset_name, eval_name, eval_result, is_higher_better) tuples.
""" """
return [item for i in range(1, self.__num_dataset) return [
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)] item
for i in range(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)
]
def save_model( def save_model(
self, self,
filename: Union[str, Path], filename: Union[str, Path],
num_iteration: Optional[int] = None, num_iteration: Optional[int] = None,
start_iteration: int = 0, start_iteration: int = 0,
importance_type: str = 'split' importance_type: str = "split",
) -> "Booster": ) -> "Booster":
"""Save Booster to file. """Save Booster to file.
...@@ -4207,19 +4445,22 @@ class Booster: ...@@ -4207,19 +4445,22 @@ class Booster:
if num_iteration is None: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type] importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
_safe_call(_LIB.LGBM_BoosterSaveModel( _safe_call(
self._handle, _LIB.LGBM_BoosterSaveModel(
ctypes.c_int(start_iteration), self._handle,
ctypes.c_int(num_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(importance_type_int), ctypes.c_int(num_iteration),
_c_str(str(filename)))) ctypes.c_int(importance_type_int),
_c_str(str(filename)),
)
)
_dump_pandas_categorical(self.pandas_categorical, filename) _dump_pandas_categorical(self.pandas_categorical, filename)
return self return self
def shuffle_models( def shuffle_models(
self, self,
start_iteration: int = 0, start_iteration: int = 0,
end_iteration: int = -1 end_iteration: int = -1,
) -> "Booster": ) -> "Booster":
"""Shuffle models. """Shuffle models.
...@@ -4236,10 +4477,13 @@ class Booster: ...@@ -4236,10 +4477,13 @@ class Booster:
self : Booster self : Booster
Booster with shuffled models. Booster with shuffled models.
""" """
_safe_call(_LIB.LGBM_BoosterShuffleModels( _safe_call(
self._handle, _LIB.LGBM_BoosterShuffleModels(
ctypes.c_int(start_iteration), self._handle,
ctypes.c_int(end_iteration))) ctypes.c_int(start_iteration),
ctypes.c_int(end_iteration),
)
)
return self return self
def model_from_string(self, model_str: str) -> "Booster": def model_from_string(self, model_str: str) -> "Booster":
...@@ -4261,14 +4505,20 @@ class Booster: ...@@ -4261,14 +4505,20 @@ class Booster:
self._free_buffer() self._free_buffer()
self._handle = ctypes.c_void_p() self._handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterLoadModelFromString( _safe_call(
_c_str(model_str), _LIB.LGBM_BoosterLoadModelFromString(
ctypes.byref(out_num_iterations), _c_str(model_str),
ctypes.byref(self._handle))) ctypes.byref(out_num_iterations),
ctypes.byref(self._handle),
)
)
out_num_class = ctypes.c_int(0) out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses( _safe_call(
self._handle, _LIB.LGBM_BoosterGetNumClasses(
ctypes.byref(out_num_class))) self._handle,
ctypes.byref(out_num_class),
)
)
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_str=model_str) self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
return self return self
...@@ -4277,7 +4527,7 @@ class Booster: ...@@ -4277,7 +4527,7 @@ class Booster:
self, self,
num_iteration: Optional[int] = None, num_iteration: Optional[int] = None,
start_iteration: int = 0, start_iteration: int = 0,
importance_type: str = 'split' importance_type: str = "split",
) -> str: ) -> str:
"""Save Booster to string. """Save Booster to string.
...@@ -4306,28 +4556,34 @@ class Booster: ...@@ -4306,28 +4556,34 @@ class Booster:
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer)) ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_BoosterSaveModelToString( _safe_call(
self._handle, _LIB.LGBM_BoosterSaveModelToString(
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_BoosterSaveModelToString(
self._handle, self._handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int), ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len), ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer,
ret = string_buffer.value.decode('utf-8') )
)
actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(
_LIB.LGBM_BoosterSaveModelToString(
self._handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer,
)
)
ret = string_buffer.value.decode("utf-8")
ret += _dump_pandas_categorical(self.pandas_categorical) ret += _dump_pandas_categorical(self.pandas_categorical)
return ret return ret
...@@ -4335,8 +4591,8 @@ class Booster: ...@@ -4335,8 +4591,8 @@ class Booster:
self, self,
num_iteration: Optional[int] = None, num_iteration: Optional[int] = None,
start_iteration: int = 0, start_iteration: int = 0,
importance_type: str = 'split', importance_type: str = "split",
object_hook: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None object_hook: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Dump Booster to JSON format. """Dump Booster to JSON format.
...@@ -4374,30 +4630,40 @@ class Booster: ...@@ -4374,30 +4630,40 @@ class Booster:
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len) string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer)) ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_BoosterDumpModel( _safe_call(
self._handle, _LIB.LGBM_BoosterDumpModel(
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
# if buffer length is not long enough, reallocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(_LIB.LGBM_BoosterDumpModel(
self._handle, self._handle,
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int), ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len), ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer,
ret = json.loads(string_buffer.value.decode('utf-8'), object_hook=object_hook) )
ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical, )
default=_json_default_with_numpy)) actual_len = tmp_out_len.value
# if buffer length is not long enough, reallocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(ctypes.addressof(string_buffer))
_safe_call(
_LIB.LGBM_BoosterDumpModel(
self._handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration),
ctypes.c_int(importance_type_int),
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer,
)
)
ret = json.loads(string_buffer.value.decode("utf-8"), object_hook=object_hook)
ret["pandas_categorical"] = json.loads(
json.dumps(
self.pandas_categorical,
default=_json_default_with_numpy,
)
)
return ret return ret
def predict( def predict(
...@@ -4410,7 +4676,7 @@ class Booster: ...@@ -4410,7 +4676,7 @@ class Booster:
pred_contrib: bool = False, pred_contrib: bool = False,
data_has_header: bool = False, data_has_header: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]: ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
"""Make a prediction. """Make a prediction.
...@@ -4474,7 +4740,7 @@ class Booster: ...@@ -4474,7 +4740,7 @@ class Booster:
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
data_has_header=data_has_header, data_has_header=data_has_header,
validate_features=validate_features validate_features=validate_features,
) )
def refit( def refit(
...@@ -4486,12 +4752,12 @@ class Booster: ...@@ -4486,12 +4752,12 @@ class Booster:
weight: Optional[_LGBM_WeightType] = None, weight: Optional[_LGBM_WeightType] = None,
group: Optional[_LGBM_GroupType] = None, group: Optional[_LGBM_GroupType] = None,
init_score: Optional[_LGBM_InitScoreType] = None, init_score: Optional[_LGBM_InitScoreType] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
dataset_params: Optional[Dict[str, Any]] = None, dataset_params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True, free_raw_data: bool = True,
validate_features: bool = False, validate_features: bool = False,
**kwargs **kwargs,
) -> "Booster": ) -> "Booster":
"""Refit the existing Booster by new data. """Refit the existing Booster by new data.
...@@ -4574,28 +4840,28 @@ class Booster: ...@@ -4574,28 +4840,28 @@ class Booster:
Refitted Booster. Refitted Booster.
""" """
if self.__set_objective_to_none: if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.') raise LightGBMError("Cannot refit due to null objective function.")
if dataset_params is None: if dataset_params is None:
dataset_params = {} dataset_params = {}
predictor = _InnerPredictor.from_booster( predictor = _InnerPredictor.from_booster(booster=self, pred_parameter=deepcopy(kwargs))
booster=self,
pred_parameter=deepcopy(kwargs)
)
leaf_preds: np.ndarray = predictor.predict( # type: ignore[assignment] leaf_preds: np.ndarray = predictor.predict( # type: ignore[assignment]
data=data, data=data,
start_iteration=-1, start_iteration=-1,
pred_leaf=True, pred_leaf=True,
validate_features=validate_features validate_features=validate_features,
) )
nrow, ncol = leaf_preds.shape nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_int(0) out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear( _safe_call(
self._handle, _LIB.LGBM_BoosterGetLinear(
ctypes.byref(out_is_linear))) self._handle,
ctypes.byref(out_is_linear),
)
)
new_params = _choose_param_value( new_params = _choose_param_value(
main_param_name="linear_tree", main_param_name="linear_tree",
params=self.params, params=self.params,
default_value=None default_value=None,
) )
new_params["linear_tree"] = bool(out_is_linear.value) new_params["linear_tree"] = bool(out_is_linear.value)
new_params.update(dataset_params) new_params.update(dataset_params)
...@@ -4611,19 +4877,25 @@ class Booster: ...@@ -4611,19 +4877,25 @@ class Booster:
params=new_params, params=new_params,
free_raw_data=free_raw_data, free_raw_data=free_raw_data,
) )
new_params['refit_decay_rate'] = decay_rate new_params["refit_decay_rate"] = decay_rate
new_booster = Booster(new_params, train_set) new_booster = Booster(new_params, train_set)
# Copy models # Copy models
_safe_call(_LIB.LGBM_BoosterMerge( _safe_call(
new_booster._handle, _LIB.LGBM_BoosterMerge(
predictor._handle)) new_booster._handle,
predictor._handle,
)
)
leaf_preds = leaf_preds.reshape(-1) leaf_preds = leaf_preds.reshape(-1)
ptr_data, _, _ = _c_int_array(leaf_preds) ptr_data, _, _ = _c_int_array(leaf_preds)
_safe_call(_LIB.LGBM_BoosterRefit( _safe_call(
new_booster._handle, _LIB.LGBM_BoosterRefit(
ptr_data, new_booster._handle,
ctypes.c_int32(nrow), ptr_data,
ctypes.c_int32(ncol))) ctypes.c_int32(nrow),
ctypes.c_int32(ncol),
)
)
new_booster._network = self._network new_booster._network = self._network
return new_booster return new_booster
...@@ -4643,11 +4915,14 @@ class Booster: ...@@ -4643,11 +4915,14 @@ class Booster:
The output of the leaf. The output of the leaf.
""" """
ret = ctypes.c_double(0) ret = ctypes.c_double(0)
_safe_call(_LIB.LGBM_BoosterGetLeafValue( _safe_call(
self._handle, _LIB.LGBM_BoosterGetLeafValue(
ctypes.c_int(tree_id), self._handle,
ctypes.c_int(leaf_id), ctypes.c_int(tree_id),
ctypes.byref(ret))) ctypes.c_int(leaf_id),
ctypes.byref(ret),
)
)
return ret.value return ret.value
def set_leaf_output( def set_leaf_output(
...@@ -4655,7 +4930,7 @@ class Booster: ...@@ -4655,7 +4930,7 @@ class Booster:
tree_id: int, tree_id: int,
leaf_id: int, leaf_id: int,
value: float, value: float,
) -> 'Booster': ) -> "Booster":
"""Set the output of a leaf. """Set the output of a leaf.
.. versionadded:: 4.0.0 .. versionadded:: 4.0.0
...@@ -4679,7 +4954,7 @@ class Booster: ...@@ -4679,7 +4954,7 @@ class Booster:
self._handle, self._handle,
ctypes.c_int(tree_id), ctypes.c_int(tree_id),
ctypes.c_int(leaf_id), ctypes.c_int(leaf_id),
ctypes.c_double(value) ctypes.c_double(value),
) )
) )
return self return self
...@@ -4693,9 +4968,12 @@ class Booster: ...@@ -4693,9 +4968,12 @@ class Booster:
The number of features. The number of features.
""" """
out_num_feature = ctypes.c_int(0) out_num_feature = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumFeature( _safe_call(
self._handle, _LIB.LGBM_BoosterGetNumFeature(
ctypes.byref(out_num_feature))) self._handle,
ctypes.byref(out_num_feature),
)
)
return out_num_feature.value return out_num_feature.value
def feature_name(self) -> List[str]: def feature_name(self) -> List[str]:
...@@ -4713,13 +4991,16 @@ class Booster: ...@@ -4713,13 +4991,16 @@ class Booster:
required_string_buffer_size = ctypes.c_size_t(0) required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)] string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc] ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc]
_safe_call(_LIB.LGBM_BoosterGetFeatureNames( _safe_call(
self._handle, _LIB.LGBM_BoosterGetFeatureNames(
ctypes.c_int(num_feature), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(num_feature),
ctypes.c_size_t(reserved_string_buffer_size), ctypes.byref(tmp_out_len),
ctypes.byref(required_string_buffer_size), ctypes.c_size_t(reserved_string_buffer_size),
ptr_string_buffers)) ctypes.byref(required_string_buffer_size),
ptr_string_buffers,
)
)
if num_feature != tmp_out_len.value: if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature") raise ValueError("Length of feature names doesn't equal with num_feature")
actual_string_buffer_size = required_string_buffer_size.value actual_string_buffer_size = required_string_buffer_size.value
...@@ -4727,19 +5008,22 @@ class Booster: ...@@ -4727,19 +5008,22 @@ class Booster:
if reserved_string_buffer_size < actual_string_buffer_size: if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)] string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc] ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc]
_safe_call(_LIB.LGBM_BoosterGetFeatureNames( _safe_call(
self._handle, _LIB.LGBM_BoosterGetFeatureNames(
ctypes.c_int(num_feature), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(num_feature),
ctypes.c_size_t(actual_string_buffer_size), ctypes.byref(tmp_out_len),
ctypes.byref(required_string_buffer_size), ctypes.c_size_t(actual_string_buffer_size),
ptr_string_buffers)) ctypes.byref(required_string_buffer_size),
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)] ptr_string_buffers,
)
)
return [string_buffers[i].value.decode("utf-8") for i in range(num_feature)]
def feature_importance( def feature_importance(
self, self,
importance_type: str = 'split', importance_type: str = "split",
iteration: Optional[int] = None iteration: Optional[int] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Get feature importances. """Get feature importances.
...@@ -4763,11 +5047,14 @@ class Booster: ...@@ -4763,11 +5047,14 @@ class Booster:
iteration = self.best_iteration iteration = self.best_iteration
importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type] importance_type_int = _FEATURE_IMPORTANCE_TYPE_MAPPER[importance_type]
result = np.empty(self.num_feature(), dtype=np.float64) result = np.empty(self.num_feature(), dtype=np.float64)
_safe_call(_LIB.LGBM_BoosterFeatureImportance( _safe_call(
self._handle, _LIB.LGBM_BoosterFeatureImportance(
ctypes.c_int(iteration), self._handle,
ctypes.c_int(importance_type_int), ctypes.c_int(iteration),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) ctypes.c_int(importance_type_int),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
)
)
if importance_type_int == _C_API_FEATURE_IMPORTANCE_SPLIT: if importance_type_int == _C_API_FEATURE_IMPORTANCE_SPLIT:
return result.astype(np.int32) return result.astype(np.int32)
else: else:
...@@ -4777,7 +5064,7 @@ class Booster: ...@@ -4777,7 +5064,7 @@ class Booster:
self, self,
feature: Union[int, str], feature: Union[int, str],
bins: Optional[Union[int, str]] = None, bins: Optional[Union[int, str]] = None,
xgboost_style: bool = False xgboost_style: bool = False,
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray, pd_DataFrame]: ) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray, pd_DataFrame]:
"""Get split value histogram for the specified feature. """Get split value histogram for the specified feature.
...@@ -4811,27 +5098,28 @@ class Booster: ...@@ -4811,27 +5098,28 @@ class Booster:
result_array_like : numpy array or pandas DataFrame (if pandas is installed) result_array_like : numpy array or pandas DataFrame (if pandas is installed)
If ``xgboost_style=True``, the histogram of used splitting values for the specified feature. If ``xgboost_style=True``, the histogram of used splitting values for the specified feature.
""" """
def add(root: Dict[str, Any]) -> None: def add(root: Dict[str, Any]) -> None:
"""Recursively add thresholds.""" """Recursively add thresholds."""
if 'split_index' in root: # non-leaf if "split_index" in root: # non-leaf
if feature_names is not None and isinstance(feature, str): if feature_names is not None and isinstance(feature, str):
split_feature = feature_names[root['split_feature']] split_feature = feature_names[root["split_feature"]]
else: else:
split_feature = root['split_feature'] split_feature = root["split_feature"]
if split_feature == feature: if split_feature == feature:
if isinstance(root['threshold'], str): if isinstance(root["threshold"], str):
raise LightGBMError('Cannot compute split value histogram for the categorical feature') raise LightGBMError("Cannot compute split value histogram for the categorical feature")
else: else:
values.append(root['threshold']) values.append(root["threshold"])
add(root['left_child']) add(root["left_child"])
add(root['right_child']) add(root["right_child"])
model = self.dump_model() model = self.dump_model()
feature_names = model.get('feature_names') feature_names = model.get("feature_names")
tree_infos = model['tree_info'] tree_infos = model["tree_info"]
values: List[float] = [] values: List[float] = []
for tree_info in tree_infos: for tree_info in tree_infos:
add(tree_info['tree_structure']) add(tree_info["tree_structure"])
if bins is None or isinstance(bins, int) and xgboost_style: if bins is None or isinstance(bins, int) and xgboost_style:
n_unique = len(np.unique(values)) n_unique = len(np.unique(values))
...@@ -4841,7 +5129,7 @@ class Booster: ...@@ -4841,7 +5129,7 @@ class Booster:
ret = np.column_stack((bin_edges[1:], hist)) ret = np.column_stack((bin_edges[1:], hist))
ret = ret[ret[:, 1] > 0] ret = ret[ret[:, 1] > 0]
if PANDAS_INSTALLED: if PANDAS_INSTALLED:
return pd_DataFrame(ret, columns=['SplitValue', 'Count']) return pd_DataFrame(ret, columns=["SplitValue", "Count"])
else: else:
return ret return ret
else: else:
...@@ -4851,7 +5139,7 @@ class Booster: ...@@ -4851,7 +5139,7 @@ class Booster:
self, self,
data_name: str, data_name: str,
data_idx: int, data_idx: int,
feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]] feval: Optional[Union[_LGBM_CustomEvalFunction, List[_LGBM_CustomEvalFunction]]],
) -> List[_LGBM_BoosterEvalMethodResultType]: ) -> List[_LGBM_BoosterEvalMethodResultType]:
"""Evaluate training or validation data.""" """Evaluate training or validation data."""
if data_idx >= self.__num_dataset: if data_idx >= self.__num_dataset:
...@@ -4861,16 +5149,18 @@ class Booster: ...@@ -4861,16 +5149,18 @@ class Booster:
if self.__num_inner_eval > 0: if self.__num_inner_eval > 0:
result = np.empty(self.__num_inner_eval, dtype=np.float64) result = np.empty(self.__num_inner_eval, dtype=np.float64)
tmp_out_len = ctypes.c_int(0) tmp_out_len = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetEval( _safe_call(
self._handle, _LIB.LGBM_BoosterGetEval(
ctypes.c_int(data_idx), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(data_idx),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) ctypes.byref(tmp_out_len),
result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
)
)
if tmp_out_len.value != self.__num_inner_eval: if tmp_out_len.value != self.__num_inner_eval:
raise ValueError("Wrong length of eval results") raise ValueError("Wrong length of eval results")
for i in range(self.__num_inner_eval): for i in range(self.__num_inner_eval):
ret.append((data_name, self.__name_inner_eval[i], ret.append((data_name, self.__name_inner_eval[i], result[i], self.__higher_better_inner_eval[i]))
result[i], self.__higher_better_inner_eval[i]))
if callable(feval): if callable(feval):
feval = [feval] feval = [feval]
if feval is not None: if feval is not None:
...@@ -4904,18 +5194,21 @@ class Booster: ...@@ -4904,18 +5194,21 @@ class Booster:
if not self.__is_predicted_cur_iter[data_idx]: if not self.__is_predicted_cur_iter[data_idx]:
tmp_out_len = ctypes.c_int64(0) tmp_out_len = ctypes.c_int64(0)
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr] data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr]
_safe_call(_LIB.LGBM_BoosterGetPredict( _safe_call(
self._handle, _LIB.LGBM_BoosterGetPredict(
ctypes.c_int(data_idx), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(data_idx),
data_ptr)) ctypes.byref(tmp_out_len),
data_ptr,
)
)
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type] if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type]
raise ValueError(f"Wrong length of predict results for data {data_idx}") raise ValueError(f"Wrong length of predict results for data {data_idx}")
self.__is_predicted_cur_iter[data_idx] = True self.__is_predicted_cur_iter[data_idx] = True
result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment] result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment]
if self.__num_class > 1: if self.__num_class > 1:
num_data = result.size // self.__num_class num_data = result.size // self.__num_class
result = result.reshape(num_data, self.__num_class, order='F') result = result.reshape(num_data, self.__num_class, order="F")
return result return result
def __get_eval_info(self) -> None: def __get_eval_info(self) -> None:
...@@ -4924,9 +5217,12 @@ class Booster: ...@@ -4924,9 +5217,12 @@ class Booster:
self.__need_reload_eval_info = False self.__need_reload_eval_info = False
out_num_eval = ctypes.c_int(0) out_num_eval = ctypes.c_int(0)
# Get num of inner evals # Get num of inner evals
_safe_call(_LIB.LGBM_BoosterGetEvalCounts( _safe_call(
self._handle, _LIB.LGBM_BoosterGetEvalCounts(
ctypes.byref(out_num_eval))) self._handle,
ctypes.byref(out_num_eval),
)
)
self.__num_inner_eval = out_num_eval.value self.__num_inner_eval = out_num_eval.value
if self.__num_inner_eval > 0: if self.__num_inner_eval > 0:
# Get name of eval metrics # Get name of eval metrics
...@@ -4937,13 +5233,16 @@ class Booster: ...@@ -4937,13 +5233,16 @@ class Booster:
ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(self.__num_inner_eval) ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(self.__num_inner_eval)
] ]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc] ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc]
_safe_call(_LIB.LGBM_BoosterGetEvalNames( _safe_call(
self._handle, _LIB.LGBM_BoosterGetEvalNames(
ctypes.c_int(self.__num_inner_eval), self._handle,
ctypes.byref(tmp_out_len), ctypes.c_int(self.__num_inner_eval),
ctypes.c_size_t(reserved_string_buffer_size), ctypes.byref(tmp_out_len),
ctypes.byref(required_string_buffer_size), ctypes.c_size_t(reserved_string_buffer_size),
ptr_string_buffers)) ctypes.byref(required_string_buffer_size),
ptr_string_buffers,
)
)
if self.__num_inner_eval != tmp_out_len.value: if self.__num_inner_eval != tmp_out_len.value:
raise ValueError("Length of eval names doesn't equal with num_evals") raise ValueError("Length of eval names doesn't equal with num_evals")
actual_string_buffer_size = required_string_buffer_size.value actual_string_buffer_size = required_string_buffer_size.value
...@@ -4952,17 +5251,20 @@ class Booster: ...@@ -4952,17 +5251,20 @@ class Booster:
string_buffers = [ string_buffers = [
ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(self.__num_inner_eval) ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(self.__num_inner_eval)
] ]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers)) # type: ignore[misc] ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(
_safe_call(_LIB.LGBM_BoosterGetEvalNames( *map(ctypes.addressof, string_buffers)
self._handle, ) # type: ignore[misc]
ctypes.c_int(self.__num_inner_eval), _safe_call(
ctypes.byref(tmp_out_len), _LIB.LGBM_BoosterGetEvalNames(
ctypes.c_size_t(actual_string_buffer_size), self._handle,
ctypes.byref(required_string_buffer_size), ctypes.c_int(self.__num_inner_eval),
ptr_string_buffers)) ctypes.byref(tmp_out_len),
self.__name_inner_eval = [ ctypes.c_size_t(actual_string_buffer_size),
string_buffers[i].value.decode('utf-8') for i in range(self.__num_inner_eval) ctypes.byref(required_string_buffer_size),
] ptr_string_buffers,
)
)
self.__name_inner_eval = [string_buffers[i].value.decode("utf-8") for i in range(self.__num_inner_eval)]
self.__higher_better_inner_eval = [ self.__higher_better_inner_eval = [
name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval name.startswith(("auc", "ndcg@", "map@", "average_precision")) for name in self.__name_inner_eval
] ]
...@@ -18,21 +18,21 @@ if TYPE_CHECKING: ...@@ -18,21 +18,21 @@ if TYPE_CHECKING:
from .engine import CVBooster from .engine import CVBooster
__all__ = [ __all__ = [
'EarlyStopException', "EarlyStopException",
'early_stopping', "early_stopping",
'log_evaluation', "log_evaluation",
'record_evaluation', "record_evaluation",
'reset_parameter', "reset_parameter",
] ]
_EvalResultDict = Dict[str, Dict[str, List[Any]]] _EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[ _EvalResultTuple = Union[
_LGBM_BoosterEvalMethodResultType, _LGBM_BoosterEvalMethodResultType,
_LGBM_BoosterEvalMethodResultWithStandardDeviationType _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
] ]
_ListOfEvalResultTuples = Union[ _ListOfEvalResultTuples = Union[
List[_LGBM_BoosterEvalMethodResultType], List[_LGBM_BoosterEvalMethodResultType],
List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType] List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType],
] ]
...@@ -95,8 +95,8 @@ class _LogEvaluationCallback: ...@@ -95,8 +95,8 @@ class _LogEvaluationCallback:
def __call__(self, env: CallbackEnv) -> None: def __call__(self, env: CallbackEnv) -> None:
if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0: if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list]) result = "\t".join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}') _log_info(f"[{env.iteration + 1}]\t{result}")
def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback: def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
...@@ -133,7 +133,7 @@ class _RecordEvaluationCallback: ...@@ -133,7 +133,7 @@ class _RecordEvaluationCallback:
self.before_iteration = False self.before_iteration = False
if not isinstance(eval_result, dict): if not isinstance(eval_result, dict):
raise TypeError('eval_result should be a dictionary') raise TypeError("eval_result should be a dictionary")
self.eval_result = eval_result self.eval_result = eval_result
def _init(self, env: CallbackEnv) -> None: def _init(self, env: CallbackEnv) -> None:
...@@ -152,8 +152,8 @@ class _RecordEvaluationCallback: ...@@ -152,8 +152,8 @@ class _RecordEvaluationCallback:
if len(item) == 4: if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, []) self.eval_result[data_name].setdefault(eval_name, [])
else: else:
self.eval_result[data_name].setdefault(f'{eval_name}-mean', []) self.eval_result[data_name].setdefault(f"{eval_name}-mean", [])
self.eval_result[data_name].setdefault(f'{eval_name}-stdv', []) self.eval_result[data_name].setdefault(f"{eval_name}-stdv", [])
def __call__(self, env: CallbackEnv) -> None: def __call__(self, env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration: if env.iteration == env.begin_iteration:
...@@ -171,8 +171,8 @@ class _RecordEvaluationCallback: ...@@ -171,8 +171,8 @@ class _RecordEvaluationCallback:
data_name, eval_name = item[1].split() data_name, eval_name = item[1].split()
res_mean = item[2] res_mean = item[2]
res_stdv = item[4] # type: ignore[misc] res_stdv = item[4] # type: ignore[misc]
self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean) self.eval_result[data_name][f"{eval_name}-mean"].append(res_mean)
self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv) self.eval_result[data_name][f"{eval_name}-stdv"].append(res_stdv)
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
...@@ -230,8 +230,10 @@ class _ResetParameterCallback: ...@@ -230,8 +230,10 @@ class _ResetParameterCallback:
elif callable(value): elif callable(value):
new_param = value(env.iteration - env.begin_iteration) new_param = value(env.iteration - env.begin_iteration)
else: else:
raise ValueError("Only list and callable values are supported " raise ValueError(
"as a mapping from boosting round index to new parameter value.") "Only list and callable values are supported "
"as a mapping from boosting round index to new parameter value."
)
if new_param != env.params.get(key, None): if new_param != env.params.get(key, None):
new_parameters[key] = new_param new_parameters[key] = new_param
if new_parameters: if new_parameters:
...@@ -276,9 +278,8 @@ class _EarlyStoppingCallback: ...@@ -276,9 +278,8 @@ class _EarlyStoppingCallback:
stopping_rounds: int, stopping_rounds: int,
first_metric_only: bool = False, first_metric_only: bool = False,
verbose: bool = True, verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0 min_delta: Union[float, List[float]] = 0.0,
) -> None: ) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0: if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}") raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
...@@ -298,7 +299,7 @@ class _EarlyStoppingCallback: ...@@ -298,7 +299,7 @@ class _EarlyStoppingCallback:
self.best_iter: List[int] = [] self.best_iter: List[int] = []
self.best_score_list: List[_ListOfEvalResultTuples] = [] self.best_score_list: List[_ListOfEvalResultTuples] = []
self.cmp_op: List[Callable[[float, float], bool]] = [] self.cmp_op: List[Callable[[float, float], bool]] = []
self.first_metric = '' self.first_metric = ""
def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool: def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta return curr_score > best_score + delta
...@@ -321,29 +322,24 @@ class _EarlyStoppingCallback: ...@@ -321,29 +322,24 @@ class _EarlyStoppingCallback:
def _init(self, env: CallbackEnv) -> None: def _init(self, env: CallbackEnv) -> None:
if env.evaluation_result_list is None or env.evaluation_result_list == []: if env.evaluation_result_list is None or env.evaluation_result_list == []:
raise ValueError( raise ValueError("For early stopping, at least one dataset and eval metric is required for evaluation")
"For early stopping, at least one dataset and eval metric is required for evaluation"
)
is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting")) is_dart = any(env.params.get(alias, "") == "dart" for alias in _ConfigAliases.get("boosting"))
if is_dart: if is_dart:
self.enabled = False self.enabled = False
_log_warning('Early stopping is not available in dart mode') _log_warning("Early stopping is not available in dart mode")
return return
# validation sets are guaranteed to not be identical to the training data in cv() # validation sets are guaranteed to not be identical to the training data in cv()
if isinstance(env.model, Booster): if isinstance(env.model, Booster):
only_train_set = ( only_train_set = len(env.evaluation_result_list) == 1 and self._is_train_set(
len(env.evaluation_result_list) == 1 ds_name=env.evaluation_result_list[0][0],
and self._is_train_set( eval_name=env.evaluation_result_list[0][1].split(" ")[0],
ds_name=env.evaluation_result_list[0][0], env=env,
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
env=env
)
) )
if only_train_set: if only_train_set:
self.enabled = False self.enabled = False
_log_warning('Only training set found, disabling early stopping.') _log_warning("Only training set found, disabling early stopping.")
return return
if self.verbose: if self.verbose:
...@@ -355,26 +351,26 @@ class _EarlyStoppingCallback: ...@@ -355,26 +351,26 @@ class _EarlyStoppingCallback:
n_datasets = len(env.evaluation_result_list) // n_metrics n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(self.min_delta, list): if isinstance(self.min_delta, list):
if not all(t >= 0 for t in self.min_delta): if not all(t >= 0 for t in self.min_delta):
raise ValueError('Values for early stopping min_delta must be non-negative.') raise ValueError("Values for early stopping min_delta must be non-negative.")
if len(self.min_delta) == 0: if len(self.min_delta) == 0:
if self.verbose: if self.verbose:
_log_info('Disabling min_delta for early stopping.') _log_info("Disabling min_delta for early stopping.")
deltas = [0.0] * n_datasets * n_metrics deltas = [0.0] * n_datasets * n_metrics
elif len(self.min_delta) == 1: elif len(self.min_delta) == 1:
if self.verbose: if self.verbose:
_log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.') _log_info(f"Using {self.min_delta[0]} as min_delta for all metrics.")
deltas = self.min_delta * n_datasets * n_metrics deltas = self.min_delta * n_datasets * n_metrics
else: else:
if len(self.min_delta) != n_metrics: if len(self.min_delta) != n_metrics:
raise ValueError('Must provide a single value for min_delta or as many as metrics.') raise ValueError("Must provide a single value for min_delta or as many as metrics.")
if self.first_metric_only and self.verbose: if self.first_metric_only and self.verbose:
_log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.') _log_info(f"Using only {self.min_delta[0]} as early stopping min_delta.")
deltas = self.min_delta * n_datasets deltas = self.min_delta * n_datasets
else: else:
if self.min_delta < 0: if self.min_delta < 0:
raise ValueError('Early stopping min_delta must be non-negative.') raise ValueError("Early stopping min_delta must be non-negative.")
if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose: if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
_log_info(f'Using {self.min_delta} as min_delta for all metrics.') _log_info(f"Using {self.min_delta} as min_delta for all metrics.")
deltas = [self.min_delta] * n_datasets * n_metrics deltas = [self.min_delta] * n_datasets * n_metrics
# split is needed for "<dataset type> <metric>" case (e.g. "train l1") # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
...@@ -382,18 +378,19 @@ class _EarlyStoppingCallback: ...@@ -382,18 +378,19 @@ class _EarlyStoppingCallback:
for eval_ret, delta in zip(env.evaluation_result_list, deltas): for eval_ret, delta in zip(env.evaluation_result_list, deltas):
self.best_iter.append(0) self.best_iter.append(0)
if eval_ret[3]: # greater is better if eval_ret[3]: # greater is better
self.best_score.append(float('-inf')) self.best_score.append(float("-inf"))
self.cmp_op.append(partial(self._gt_delta, delta=delta)) self.cmp_op.append(partial(self._gt_delta, delta=delta))
else: else:
self.best_score.append(float('inf')) self.best_score.append(float("inf"))
self.cmp_op.append(partial(self._lt_delta, delta=delta)) self.cmp_op.append(partial(self._lt_delta, delta=delta))
def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None: def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
if env.iteration == env.end_iteration - 1: if env.iteration == env.end_iteration - 1:
if self.verbose: if self.verbose:
best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]]) best_score_str = "\t".join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
_log_info('Did not meet early stopping. ' _log_info(
f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}') "Did not meet early stopping. " f"Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}"
)
if self.first_metric_only: if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}") _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
raise EarlyStopException(self.best_iter[i], self.best_score_list[i]) raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
...@@ -409,7 +406,7 @@ class _EarlyStoppingCallback: ...@@ -409,7 +406,7 @@ class _EarlyStoppingCallback:
"Please report it at https://github.com/microsoft/LightGBM/issues" "Please report it at https://github.com/microsoft/LightGBM/issues"
) )
# self.best_score_list is initialized to an empty list # self.best_score_list is initialized to an empty list
first_time_updating_best_score_list = (self.best_score_list == []) first_time_updating_best_score_list = self.best_score_list == []
for i in range(len(env.evaluation_result_list)): for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2] score = env.evaluation_result_list[i][2]
if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]): if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
...@@ -426,12 +423,14 @@ class _EarlyStoppingCallback: ...@@ -426,12 +423,14 @@ class _EarlyStoppingCallback:
if self._is_train_set( if self._is_train_set(
ds_name=env.evaluation_result_list[i][0], ds_name=env.evaluation_result_list[i][0],
eval_name=eval_name_splitted[0], eval_name=eval_name_splitted[0],
env=env env=env,
): ):
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train) continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds: elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose: if self.verbose:
eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]]) eval_result_str = "\t".join(
[_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]]
)
_log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}") _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
if self.first_metric_only: if self.first_metric_only:
_log_info(f"Evaluated only: {eval_name_splitted[-1]}") _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
...@@ -439,7 +438,12 @@ class _EarlyStoppingCallback: ...@@ -439,7 +438,12 @@ class _EarlyStoppingCallback:
self._final_iteration_check(env, eval_name_splitted, i) self._final_iteration_check(env, eval_name_splitted, i)
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback: def early_stopping(
stopping_rounds: int,
first_metric_only: bool = False,
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
) -> _EarlyStoppingCallback:
"""Create a callback that activates early stopping. """Create a callback that activates early stopping.
Activates early stopping. Activates early stopping.
...@@ -473,4 +477,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos ...@@ -473,4 +477,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
callback : _EarlyStoppingCallback callback : _EarlyStoppingCallback
The callback that activates early stopping. The callback that activates early stopping.
""" """
return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta) return _EarlyStoppingCallback(
stopping_rounds=stopping_rounds,
first_metric_only=first_metric_only,
verbose=verbose,
min_delta=min_delta,
)
...@@ -8,6 +8,7 @@ try: ...@@ -8,6 +8,7 @@ try:
from pandas import DataFrame as pd_DataFrame from pandas import DataFrame as pd_DataFrame
from pandas import Series as pd_Series from pandas import Series as pd_Series
from pandas import concat from pandas import concat
try: try:
from pandas import CategoricalDtype as pd_CategoricalDtype from pandas import CategoricalDtype as pd_CategoricalDtype
except ImportError: except ImportError:
...@@ -40,15 +41,18 @@ except ImportError: ...@@ -40,15 +41,18 @@ except ImportError:
try: try:
from numpy.random import Generator as np_random_Generator from numpy.random import Generator as np_random_Generator
except ImportError: except ImportError:
class np_random_Generator: # type: ignore class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator.""" """Dummy class for np.random.Generator."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
"""matplotlib""" """matplotlib"""
try: try:
import matplotlib # noqa: F401 import matplotlib # noqa: F401
MATPLOTLIB_INSTALLED = True MATPLOTLIB_INSTALLED = True
except ImportError: except ImportError:
MATPLOTLIB_INSTALLED = False MATPLOTLIB_INSTALLED = False
...@@ -56,6 +60,7 @@ except ImportError: ...@@ -56,6 +60,7 @@ except ImportError:
"""graphviz""" """graphviz"""
try: try:
import graphviz # noqa: F401 import graphviz # noqa: F401
GRAPHVIZ_INSTALLED = True GRAPHVIZ_INSTALLED = True
except ImportError: except ImportError:
GRAPHVIZ_INSTALLED = False GRAPHVIZ_INSTALLED = False
...@@ -63,6 +68,7 @@ except ImportError: ...@@ -63,6 +68,7 @@ except ImportError:
"""datatable""" """datatable"""
try: try:
import datatable import datatable
if hasattr(datatable, "Frame"): if hasattr(datatable, "Frame"):
dt_DataTable = datatable.Frame dt_DataTable = datatable.Frame
else: else:
...@@ -85,6 +91,7 @@ try: ...@@ -85,6 +91,7 @@ try:
from sklearn.utils.class_weight import compute_sample_weight from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
try: try:
from sklearn.exceptions import NotFittedError from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
...@@ -155,6 +162,7 @@ try: ...@@ -155,6 +162,7 @@ try:
from dask.dataframe import DataFrame as dask_DataFrame from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series from dask.dataframe import Series as dask_Series
from dask.distributed import Client, Future, default_client, wait from dask.distributed import Client, Future, default_client, wait
DASK_INSTALLED = True DASK_INSTALLED = True
except ImportError: except ImportError:
DASK_INSTALLED = False DASK_INSTALLED = False
...@@ -195,6 +203,7 @@ except ImportError: ...@@ -195,6 +203,7 @@ except ImportError:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
"""pyarrow""" """pyarrow"""
try: try:
import pyarrow.compute as pa_compute import pyarrow.compute as pa_compute
...@@ -205,6 +214,7 @@ try: ...@@ -205,6 +214,7 @@ try:
from pyarrow.cffi import ffi as arrow_cffi from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_floating as arrow_is_floating from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer from pyarrow.types import is_integer as arrow_is_integer
PYARROW_INSTALLED = True PYARROW_INSTALLED = True
except ImportError: except ImportError:
PYARROW_INSTALLED = False PYARROW_INSTALLED = False
...@@ -266,4 +276,5 @@ except ImportError: ...@@ -266,4 +276,5 @@ except ImportError:
def _LGBMCpuCount(only_physical_cores: bool = True) -> int: def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
return cpu_count() return cpu_count()
__all__: List[str] = [] __all__: List[str] = []
...@@ -51,9 +51,9 @@ from .sklearn import ( ...@@ -51,9 +51,9 @@ from .sklearn import (
) )
__all__ = [ __all__ = [
'DaskLGBMClassifier', "DaskLGBMClassifier",
'DaskLGBMRanker', "DaskLGBMRanker",
'DaskLGBMRegressor', "DaskLGBMRegressor",
] ]
_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series] _DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
...@@ -67,7 +67,7 @@ class _RemoteSocket: ...@@ -67,7 +67,7 @@ class _RemoteSocket:
def acquire(self) -> int: def acquire(self) -> int:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(('', 0)) self.socket.bind(("", 0))
return self.socket.getsockname()[1] return self.socket.getsockname()[1]
def release(self) -> None: def release(self) -> None:
...@@ -153,9 +153,11 @@ def _concat(seq: List[_DaskPart]) -> _DaskPart: ...@@ -153,9 +153,11 @@ def _concat(seq: List[_DaskPart]) -> _DaskPart:
elif isinstance(seq[0], (pd_DataFrame, pd_Series)): elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
return concat(seq, axis=0) return concat(seq, axis=0)
elif isinstance(seq[0], ss.spmatrix): elif isinstance(seq[0], ss.spmatrix):
return ss.vstack(seq, format='csr') return ss.vstack(seq, format="csr")
else: else:
raise TypeError(f'Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got {type(seq[0]).__name__}.') raise TypeError(
f"Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got {type(seq[0]).__name__}."
)
def _remove_list_padding(*args: Any) -> List[List[Any]]: def _remove_list_padding(*args: Any) -> List[List[Any]]:
...@@ -186,41 +188,41 @@ def _train_part( ...@@ -186,41 +188,41 @@ def _train_part(
return_model: bool, return_model: bool,
time_out: int, time_out: int,
remote_socket: _RemoteSocket, remote_socket: _RemoteSocket,
**kwargs: Any **kwargs: Any,
) -> Optional[LGBMModel]: ) -> Optional[LGBMModel]:
network_params = { network_params = {
'machines': machines, "machines": machines,
'local_listen_port': local_listen_port, "local_listen_port": local_listen_port,
'time_out': time_out, "time_out": time_out,
'num_machines': num_machines "num_machines": num_machines,
} }
params.update(network_params) params.update(network_params)
is_ranker = issubclass(model_factory, LGBMRanker) is_ranker = issubclass(model_factory, LGBMRanker)
# Concatenate many parts into one # Concatenate many parts into one
data = _concat([x['data'] for x in list_of_parts]) data = _concat([x["data"] for x in list_of_parts])
label = _concat([x['label'] for x in list_of_parts]) label = _concat([x["label"] for x in list_of_parts])
if 'weight' in list_of_parts[0]: if "weight" in list_of_parts[0]:
weight = _concat([x['weight'] for x in list_of_parts]) weight = _concat([x["weight"] for x in list_of_parts])
else: else:
weight = None weight = None
if 'group' in list_of_parts[0]: if "group" in list_of_parts[0]:
group = _concat([x['group'] for x in list_of_parts]) group = _concat([x["group"] for x in list_of_parts])
else: else:
group = None group = None
if 'init_score' in list_of_parts[0]: if "init_score" in list_of_parts[0]:
init_score = _concat([x['init_score'] for x in list_of_parts]) init_score = _concat([x["init_score"] for x in list_of_parts])
else: else:
init_score = None init_score = None
# construct local eval_set data. # construct local eval_set data.
n_evals = max(len(x.get('eval_set', [])) for x in list_of_parts) n_evals = max(len(x.get("eval_set", [])) for x in list_of_parts)
eval_names = kwargs.pop('eval_names', None) eval_names = kwargs.pop("eval_names", None)
eval_class_weight = kwargs.get('eval_class_weight') eval_class_weight = kwargs.get("eval_class_weight")
local_eval_set = None local_eval_set = None
local_eval_names = None local_eval_names = None
local_eval_sample_weight = None local_eval_sample_weight = None
...@@ -228,8 +230,8 @@ def _train_part( ...@@ -228,8 +230,8 @@ def _train_part(
local_eval_group = None local_eval_group = None
if n_evals: if n_evals:
has_eval_sample_weight = any(x.get('eval_sample_weight') is not None for x in list_of_parts) has_eval_sample_weight = any(x.get("eval_sample_weight") is not None for x in list_of_parts)
has_eval_init_score = any(x.get('eval_init_score') is not None for x in list_of_parts) has_eval_init_score = any(x.get("eval_init_score") is not None for x in list_of_parts)
local_eval_set = [] local_eval_set = []
evals_result_names = [] evals_result_names = []
...@@ -251,7 +253,7 @@ def _train_part( ...@@ -251,7 +253,7 @@ def _train_part(
init_score_e = [] init_score_e = []
g_e = [] g_e = []
for part in list_of_parts: for part in list_of_parts:
if not part.get('eval_set'): if not part.get("eval_set"):
continue continue
# require that eval_name exists in evaluated result data in case dropped due to padding. # require that eval_name exists in evaluated result data in case dropped due to padding.
...@@ -259,12 +261,12 @@ def _train_part( ...@@ -259,12 +261,12 @@ def _train_part(
if eval_names: if eval_names:
evals_result_name = eval_names[i] evals_result_name = eval_names[i]
else: else:
evals_result_name = f'valid_{i}' evals_result_name = f"valid_{i}"
eval_set = part['eval_set'][i] eval_set = part["eval_set"][i]
if eval_set is _DatasetNames.TRAINSET: if eval_set is _DatasetNames.TRAINSET:
x_e.append(part['data']) x_e.append(part["data"])
y_e.append(part['label']) y_e.append(part["label"])
else: else:
x_e.extend(eval_set[0]) x_e.extend(eval_set[0])
y_e.extend(eval_set[1]) y_e.extend(eval_set[1])
...@@ -272,24 +274,24 @@ def _train_part( ...@@ -272,24 +274,24 @@ def _train_part(
if evals_result_name not in evals_result_names: if evals_result_name not in evals_result_names:
evals_result_names.append(evals_result_name) evals_result_names.append(evals_result_name)
eval_weight = part.get('eval_sample_weight') eval_weight = part.get("eval_sample_weight")
if eval_weight: if eval_weight:
if eval_weight[i] is _DatasetNames.SAMPLE_WEIGHT: if eval_weight[i] is _DatasetNames.SAMPLE_WEIGHT:
w_e.append(part['weight']) w_e.append(part["weight"])
else: else:
w_e.extend(eval_weight[i]) w_e.extend(eval_weight[i])
eval_init_score = part.get('eval_init_score') eval_init_score = part.get("eval_init_score")
if eval_init_score: if eval_init_score:
if eval_init_score[i] is _DatasetNames.INIT_SCORE: if eval_init_score[i] is _DatasetNames.INIT_SCORE:
init_score_e.append(part['init_score']) init_score_e.append(part["init_score"])
else: else:
init_score_e.extend(eval_init_score[i]) init_score_e.extend(eval_init_score[i])
eval_group = part.get('eval_group') eval_group = part.get("eval_group")
if eval_group: if eval_group:
if eval_group[i] is _DatasetNames.GROUP: if eval_group[i] is _DatasetNames.GROUP:
g_e.append(part['group']) g_e.append(part["group"])
else: else:
g_e.extend(eval_group[i]) g_e.extend(eval_group[i])
...@@ -313,7 +315,7 @@ def _train_part( ...@@ -313,7 +315,7 @@ def _train_part(
if eval_names: if eval_names:
local_eval_names = [eval_names[i] for i in eval_component_idx] local_eval_names = [eval_names[i] for i in eval_component_idx]
if eval_class_weight: if eval_class_weight:
kwargs['eval_class_weight'] = [eval_class_weight[i] for i in eval_component_idx] kwargs["eval_class_weight"] = [eval_class_weight[i] for i in eval_component_idx]
model = model_factory(**params) model = model_factory(**params)
if remote_socket is not None: if remote_socket is not None:
...@@ -331,7 +333,7 @@ def _train_part( ...@@ -331,7 +333,7 @@ def _train_part(
eval_init_score=local_eval_init_score, eval_init_score=local_eval_init_score,
eval_group=local_eval_group, eval_group=local_eval_group,
eval_names=local_eval_names, eval_names=local_eval_names,
**kwargs **kwargs,
) )
else: else:
model.fit( model.fit(
...@@ -343,7 +345,7 @@ def _train_part( ...@@ -343,7 +345,7 @@ def _train_part(
eval_sample_weight=local_eval_sample_weight, eval_sample_weight=local_eval_sample_weight,
eval_init_score=local_eval_init_score, eval_init_score=local_eval_init_score,
eval_names=local_eval_names, eval_names=local_eval_names,
**kwargs **kwargs,
) )
finally: finally:
...@@ -389,7 +391,9 @@ def _machines_to_worker_map(machines: str, worker_addresses: Iterable[str]) -> D ...@@ -389,7 +391,9 @@ def _machines_to_worker_map(machines: str, worker_addresses: Iterable[str]) -> D
machine_addresses = machines.split(",") machine_addresses = machines.split(",")
if len(set(machine_addresses)) != len(machine_addresses): if len(set(machine_addresses)) != len(machine_addresses):
raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.") raise ValueError(
f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination."
)
machine_to_port = defaultdict(set) machine_to_port = defaultdict(set)
for address in machine_addresses: for address in machine_addresses:
...@@ -423,7 +427,7 @@ def _train( ...@@ -423,7 +427,7 @@ def _train(
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None, eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
**kwargs: Any **kwargs: Any,
) -> LGBMModel: ) -> LGBMModel:
"""Inner train routine. """Inner train routine.
...@@ -512,36 +516,34 @@ def _train( ...@@ -512,36 +516,34 @@ def _train(
params = deepcopy(params) params = deepcopy(params)
# capture whether local_listen_port or its aliases were provided # capture whether local_listen_port or its aliases were provided
listen_port_in_params = any( listen_port_in_params = any(alias in params for alias in _ConfigAliases.get("local_listen_port"))
alias in params for alias in _ConfigAliases.get("local_listen_port")
)
# capture whether machines or its aliases were provided # capture whether machines or its aliases were provided
machines_in_params = any( machines_in_params = any(alias in params for alias in _ConfigAliases.get("machines"))
alias in params for alias in _ConfigAliases.get("machines")
)
params = _choose_param_value( params = _choose_param_value(
main_param_name="tree_learner", main_param_name="tree_learner",
params=params, params=params,
default_value="data" default_value="data",
) )
allowed_tree_learners = { allowed_tree_learners = {
'data', "data",
'data_parallel', "data_parallel",
'feature', "feature",
'feature_parallel', "feature_parallel",
'voting', "voting",
'voting_parallel' "voting_parallel",
} }
if params["tree_learner"] not in allowed_tree_learners: if params["tree_learner"] not in allowed_tree_learners:
_log_warning(f'Parameter tree_learner set to {params["tree_learner"]}, which is not allowed. Using "data" as default') _log_warning(
params['tree_learner'] = 'data' f'Parameter tree_learner set to {params["tree_learner"]}, which is not allowed. Using "data" as default'
)
params["tree_learner"] = "data"
# Some passed-in parameters can be removed: # Some passed-in parameters can be removed:
# * 'num_machines': set automatically from Dask worker list # * 'num_machines': set automatically from Dask worker list
# * 'num_threads': overridden to match nthreads on each Dask process # * 'num_threads': overridden to match nthreads on each Dask process
for param_alias in _ConfigAliases.get('num_machines', 'num_threads'): for param_alias in _ConfigAliases.get("num_machines", "num_threads"):
if param_alias in params: if param_alias in params:
_log_warning(f"Parameter {param_alias} will be ignored.") _log_warning(f"Parameter {param_alias} will be ignored.")
params.pop(param_alias) params.pop(param_alias)
...@@ -549,23 +551,23 @@ def _train( ...@@ -549,23 +551,23 @@ def _train(
# Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality # Split arrays/dataframes into parts. Arrange parts into dicts to enforce co-locality
data_parts = _split_to_parts(data=data, is_matrix=True) data_parts = _split_to_parts(data=data, is_matrix=True)
label_parts = _split_to_parts(data=label, is_matrix=False) label_parts = _split_to_parts(data=label, is_matrix=False)
parts = [{'data': x, 'label': y} for (x, y) in zip(data_parts, label_parts)] parts = [{"data": x, "label": y} for (x, y) in zip(data_parts, label_parts)]
n_parts = len(parts) n_parts = len(parts)
if sample_weight is not None: if sample_weight is not None:
weight_parts = _split_to_parts(data=sample_weight, is_matrix=False) weight_parts = _split_to_parts(data=sample_weight, is_matrix=False)
for i in range(n_parts): for i in range(n_parts):
parts[i]['weight'] = weight_parts[i] parts[i]["weight"] = weight_parts[i]
if group is not None: if group is not None:
group_parts = _split_to_parts(data=group, is_matrix=False) group_parts = _split_to_parts(data=group, is_matrix=False)
for i in range(n_parts): for i in range(n_parts):
parts[i]['group'] = group_parts[i] parts[i]["group"] = group_parts[i]
if init_score is not None: if init_score is not None:
init_score_parts = _split_to_parts(data=init_score, is_matrix=False) init_score_parts = _split_to_parts(data=init_score, is_matrix=False)
for i in range(n_parts): for i in range(n_parts):
parts[i]['init_score'] = init_score_parts[i] parts[i]["init_score"] = init_score_parts[i]
# evals_set will to be re-constructed into smaller lists of (X, y) tuples, where # evals_set will to be re-constructed into smaller lists of (X, y) tuples, where
# X and y are each delayed sub-lists of original eval dask Collections. # X and y are each delayed sub-lists of original eval dask Collections.
...@@ -575,47 +577,16 @@ def _train( ...@@ -575,47 +577,16 @@ def _train(
n_largest_eval_parts = max(x[0].npartitions for x in eval_set) n_largest_eval_parts = max(x[0].npartitions for x in eval_set)
eval_sets: Dict[ eval_sets: Dict[
int, int, List[Union[_DatasetNames, Tuple[List[Optional[_DaskMatrixLike]], List[Optional[_DaskVectorLike]]]]]
List[
Union[
_DatasetNames,
Tuple[
List[Optional[_DaskMatrixLike]],
List[Optional[_DaskVectorLike]]
]
]
]
] = defaultdict(list) ] = defaultdict(list)
if eval_sample_weight: if eval_sample_weight:
eval_sample_weights: Dict[ eval_sample_weights: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskVectorLike]]]]] = defaultdict(
int, list
List[ )
Union[
_DatasetNames,
List[Optional[_DaskVectorLike]]
]
]
] = defaultdict(list)
if eval_group: if eval_group:
eval_groups: Dict[ eval_groups: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskVectorLike]]]]] = defaultdict(list)
int,
List[
Union[
_DatasetNames,
List[Optional[_DaskVectorLike]]
]
]
] = defaultdict(list)
if eval_init_score: if eval_init_score:
eval_init_scores: Dict[ eval_init_scores: Dict[int, List[Union[_DatasetNames, List[Optional[_DaskMatrixLike]]]]] = defaultdict(list)
int,
List[
Union[
_DatasetNames,
List[Optional[_DaskMatrixLike]]
]
]
] = defaultdict(list)
for i, (X_eval, y_eval) in enumerate(eval_set): for i, (X_eval, y_eval) in enumerate(eval_set):
n_this_eval_parts = X_eval.npartitions n_this_eval_parts = X_eval.npartitions
...@@ -704,13 +675,13 @@ def _train( ...@@ -704,13 +675,13 @@ def _train(
# assign sub-eval_set components to worker parts. # assign sub-eval_set components to worker parts.
for parts_idx, e_set in eval_sets.items(): for parts_idx, e_set in eval_sets.items():
parts[parts_idx]['eval_set'] = e_set parts[parts_idx]["eval_set"] = e_set
if eval_sample_weight: if eval_sample_weight:
parts[parts_idx]['eval_sample_weight'] = eval_sample_weights[parts_idx] parts[parts_idx]["eval_sample_weight"] = eval_sample_weights[parts_idx]
if eval_init_score: if eval_init_score:
parts[parts_idx]['eval_init_score'] = eval_init_scores[parts_idx] parts[parts_idx]["eval_init_score"] = eval_init_scores[parts_idx]
if eval_group: if eval_group:
parts[parts_idx]['eval_group'] = eval_groups[parts_idx] parts[parts_idx]["eval_group"] = eval_groups[parts_idx]
# Start computation in the background # Start computation in the background
parts = list(map(delayed, parts)) parts = list(map(delayed, parts))
...@@ -718,7 +689,7 @@ def _train( ...@@ -718,7 +689,7 @@ def _train(
wait(parts) wait(parts)
for part in parts: for part in parts:
if part.status == 'error': # type: ignore if part.status == "error": # type: ignore
# trigger error locally # trigger error locally
return part # type: ignore[return-value] return part # type: ignore[return-value]
...@@ -735,7 +706,7 @@ def _train( ...@@ -735,7 +706,7 @@ def _train(
for worker in worker_map: for worker in worker_map:
has_eval_set = False has_eval_set = False
for part in worker_map[worker]: for part in worker_map[worker]:
if 'eval_set' in part.result(): # type: ignore[attr-defined] if "eval_set" in part.result(): # type: ignore[attr-defined]
has_eval_set = True has_eval_set = True
break break
...@@ -747,13 +718,13 @@ def _train( ...@@ -747,13 +718,13 @@ def _train(
# assign general validation set settings to fit kwargs. # assign general validation set settings to fit kwargs.
if eval_names: if eval_names:
kwargs['eval_names'] = eval_names kwargs["eval_names"] = eval_names
if eval_class_weight: if eval_class_weight:
kwargs['eval_class_weight'] = eval_class_weight kwargs["eval_class_weight"] = eval_class_weight
if eval_metric: if eval_metric:
kwargs['eval_metric'] = eval_metric kwargs["eval_metric"] = eval_metric
if eval_at: if eval_at:
kwargs['eval_at'] = eval_at kwargs["eval_at"] = eval_at
master_worker = next(iter(worker_map)) master_worker = next(iter(worker_map))
worker_ncores = client.ncores() worker_ncores = client.ncores()
...@@ -763,14 +734,14 @@ def _train( ...@@ -763,14 +734,14 @@ def _train(
params = _choose_param_value( params = _choose_param_value(
main_param_name="local_listen_port", main_param_name="local_listen_port",
params=params, params=params,
default_value=12400 default_value=12400,
) )
local_listen_port = params.pop("local_listen_port") local_listen_port = params.pop("local_listen_port")
params = _choose_param_value( params = _choose_param_value(
main_param_name="machines", main_param_name="machines",
params=params, params=params,
default_value=None default_value=None,
) )
machines = params.pop("machines") machines = params.pop("machines")
...@@ -781,7 +752,7 @@ def _train( ...@@ -781,7 +752,7 @@ def _train(
_log_info("Using passed-in 'machines' parameter") _log_info("Using passed-in 'machines' parameter")
worker_address_to_port = _machines_to_worker_map( worker_address_to_port = _machines_to_worker_map(
machines=machines, machines=machines,
worker_addresses=worker_addresses worker_addresses=worker_addresses,
) )
else: else:
if listen_port_in_params: if listen_port_in_params:
...@@ -795,19 +766,16 @@ def _train( ...@@ -795,19 +766,16 @@ def _train(
) )
raise LightGBMError(msg) raise LightGBMError(msg)
worker_address_to_port = { worker_address_to_port = {address: local_listen_port for address in worker_addresses}
address: local_listen_port
for address in worker_addresses
}
else: else:
_log_info("Finding random open ports for workers") _log_info("Finding random open ports for workers")
worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(client, list(worker_map.keys())) worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(
client, list(worker_map.keys())
)
machines = ','.join([ machines = ",".join(
f'{urlparse(worker_address).hostname}:{port}' [f"{urlparse(worker_address).hostname}:{port}" for worker_address, port in worker_address_to_port.items()]
for worker_address, port )
in worker_address_to_port.items()
])
num_machines = len(worker_address_to_port) num_machines = len(worker_address_to_port)
...@@ -823,18 +791,18 @@ def _train( ...@@ -823,18 +791,18 @@ def _train(
client.submit( client.submit(
_train_part, _train_part,
model_factory=model_factory, model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]}, params={**params, "num_threads": worker_ncores[worker]},
list_of_parts=list_of_parts, list_of_parts=list_of_parts,
machines=machines, machines=machines,
local_listen_port=worker_address_to_port[worker], local_listen_port=worker_address_to_port[worker],
num_machines=num_machines, num_machines=num_machines,
time_out=params.get('time_out', 120), time_out=params.get("time_out", 120),
remote_socket=worker_to_socket_future.get(worker, None), remote_socket=worker_to_socket_future.get(worker, None),
return_model=(worker == master_worker), return_model=(worker == master_worker),
workers=[worker], workers=[worker],
allow_other_workers=False, allow_other_workers=False,
pure=False, pure=False,
**kwargs **kwargs,
) )
for worker, list_of_parts in worker_map.items() for worker, list_of_parts in worker_map.items()
] ]
...@@ -848,14 +816,14 @@ def _train( ...@@ -848,14 +816,14 @@ def _train(
# on the Dask cluster you're connected to and which workers have pieces of # on the Dask cluster you're connected to and which workers have pieces of
# the training data # the training data
if not listen_port_in_params: if not listen_port_in_params:
for param in _ConfigAliases.get('local_listen_port'): for param in _ConfigAliases.get("local_listen_port"):
model._other_params.pop(param, None) model._other_params.pop(param, None)
if not machines_in_params: if not machines_in_params:
for param in _ConfigAliases.get('machines'): for param in _ConfigAliases.get("machines"):
model._other_params.pop(param, None) model._other_params.pop(param, None)
for param in _ConfigAliases.get('num_machines', 'timeout'): for param in _ConfigAliases.get("num_machines", "timeout"):
model._other_params.pop(param, None) model._other_params.pop(param, None)
return model return model
...@@ -868,9 +836,8 @@ def _predict_part( ...@@ -868,9 +836,8 @@ def _predict_part(
pred_proba: bool, pred_proba: bool,
pred_leaf: bool, pred_leaf: bool,
pred_contrib: bool, pred_contrib: bool,
**kwargs: Any **kwargs: Any,
) -> _DaskPart: ) -> _DaskPart:
result: _DaskPart result: _DaskPart
if part.shape[0] == 0: if part.shape[0] == 0:
result = np.array([]) result = np.array([])
...@@ -880,7 +847,7 @@ def _predict_part( ...@@ -880,7 +847,7 @@ def _predict_part(
raw_score=raw_score, raw_score=raw_score,
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
**kwargs **kwargs,
) )
else: else:
result = model.predict( result = model.predict(
...@@ -888,7 +855,7 @@ def _predict_part( ...@@ -888,7 +855,7 @@ def _predict_part(
raw_score=raw_score, raw_score=raw_score,
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
**kwargs **kwargs,
) )
# dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series # dask.DataFrame.map_partitions() expects each call to return a pandas DataFrame or Series
...@@ -896,7 +863,7 @@ def _predict_part( ...@@ -896,7 +863,7 @@ def _predict_part(
if len(result.shape) == 2: if len(result.shape) == 2:
result = pd_DataFrame(result, index=part.index) result = pd_DataFrame(result, index=part.index)
else: else:
result = pd_Series(result, index=part.index, name='predictions') result = pd_Series(result, index=part.index, name="predictions")
return result return result
...@@ -910,7 +877,7 @@ def _predict( ...@@ -910,7 +877,7 @@ def _predict(
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
dtype: _PredictionDtype = np.float32, dtype: _PredictionDtype = np.float32,
**kwargs: Any **kwargs: Any,
) -> Union[dask_Array, List[dask_Array]]: ) -> Union[dask_Array, List[dask_Array]]:
"""Inner predict routine. """Inner predict routine.
...@@ -943,7 +910,7 @@ def _predict( ...@@ -943,7 +910,7 @@ def _predict(
If ``pred_contrib=True``, the feature contributions for each sample. If ``pred_contrib=True``, the feature contributions for each sample.
""" """
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') raise LightGBMError("dask, pandas and scikit-learn are required for lightgbm.dask")
if isinstance(data, dask_DataFrame): if isinstance(data, dask_DataFrame):
return data.map_partitions( return data.map_partitions(
_predict_part, _predict_part,
...@@ -952,19 +919,14 @@ def _predict( ...@@ -952,19 +919,14 @@ def _predict(
pred_proba=pred_proba, pred_proba=pred_proba,
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
**kwargs **kwargs,
).values ).values
elif isinstance(data, dask_Array): elif isinstance(data, dask_Array):
# for multi-class classification with sparse matrices, pred_contrib predictions # for multi-class classification with sparse matrices, pred_contrib predictions
# are returned as a list of sparse matrices (one per class) # are returned as a list of sparse matrices (one per class)
num_classes = model._n_classes num_classes = model._n_classes
if ( if num_classes > 2 and pred_contrib and isinstance(data._meta, ss.spmatrix):
num_classes > 2
and pred_contrib
and isinstance(data._meta, ss.spmatrix)
):
predict_function = partial( predict_function = partial(
_predict_part, _predict_part,
model=model, model=model,
...@@ -972,7 +934,7 @@ def _predict( ...@@ -972,7 +934,7 @@ def _predict(
pred_proba=pred_proba, pred_proba=pred_proba,
pred_leaf=False, pred_leaf=False,
pred_contrib=True, pred_contrib=True,
**kwargs **kwargs,
) )
delayed_chunks = data.to_delayed() delayed_chunks = data.to_delayed()
...@@ -999,16 +961,16 @@ def _predict( ...@@ -999,16 +961,16 @@ def _predict(
part = dask_array_from_delayed( part = dask_array_from_delayed(
value=_extract(partition, i), value=_extract(partition, i),
shape=(nrows_per_chunk[j], num_cols), shape=(nrows_per_chunk[j], num_cols),
meta=pred_meta meta=pred_meta,
) )
out[i].append(part) out[i].append(part)
# by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix # by default, dask.array.concatenate() concatenates sparse arrays into a COO matrix
# the code below is used instead to ensure that the sparse type is preserved during concatentation # the code below is used instead to ensure that the sparse type is preserved during concatentation
if isinstance(pred_meta, ss.csr_matrix): if isinstance(pred_meta, ss.csr_matrix):
concat_fn = partial(ss.vstack, format='csr') concat_fn = partial(ss.vstack, format="csr")
elif isinstance(pred_meta, ss.csc_matrix): elif isinstance(pred_meta, ss.csc_matrix):
concat_fn = partial(ss.vstack, format='csc') concat_fn = partial(ss.vstack, format="csc")
else: else:
concat_fn = ss.vstack concat_fn = ss.vstack
...@@ -1020,7 +982,7 @@ def _predict( ...@@ -1020,7 +982,7 @@ def _predict(
dask_array_from_delayed( dask_array_from_delayed(
value=delayed(concat_fn)(out[i]), value=delayed(concat_fn)(out[i]),
shape=(data.shape[0], num_cols), shape=(data.shape[0], num_cols),
meta=pred_meta meta=pred_meta,
) )
) )
...@@ -1042,7 +1004,7 @@ def _predict( ...@@ -1042,7 +1004,7 @@ def _predict(
if len(pred_row.shape) > 1: if len(pred_row.shape) > 1:
chunks += (pred_row.shape[1],) chunks += (pred_row.shape[1],)
else: else:
map_blocks_kwargs['drop_axis'] = 1 map_blocks_kwargs["drop_axis"] = 1
return data.map_blocks( return data.map_blocks(
predict_fn, predict_fn,
chunks=chunks, chunks=chunks,
...@@ -1051,11 +1013,10 @@ def _predict( ...@@ -1051,11 +1013,10 @@ def _predict(
**map_blocks_kwargs, **map_blocks_kwargs,
) )
else: else:
raise TypeError(f'Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.') raise TypeError(f"Data must be either Dask Array or Dask DataFrame. Got {type(data).__name__}.")
class _DaskLGBMModel: class _DaskLGBMModel:
@property @property
def client_(self) -> Client: def client_(self) -> Client:
""":obj:`dask.distributed.Client`: Dask client. """:obj:`dask.distributed.Client`: Dask client.
...@@ -1064,7 +1025,7 @@ class _DaskLGBMModel: ...@@ -1064,7 +1025,7 @@ class _DaskLGBMModel:
with ``model.set_params(client=client)``. with ``model.set_params(client=client)``.
""" """
if not getattr(self, "fitted_", False): if not getattr(self, "fitted_", False):
raise LGBMNotFittedError('Cannot access property client_ before calling fit().') raise LGBMNotFittedError("Cannot access property client_ before calling fit().")
return _get_dask_client(client=self.client) return _get_dask_client(client=self.client)
...@@ -1093,12 +1054,12 @@ class _DaskLGBMModel: ...@@ -1093,12 +1054,12 @@ class _DaskLGBMModel:
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None, eval_at: Optional[Union[List[int], Tuple[int, ...]]] = None,
**kwargs: Any **kwargs: Any,
) -> "_DaskLGBMModel": ) -> "_DaskLGBMModel":
if not DASK_INSTALLED: if not DASK_INSTALLED:
raise LightGBMError('dask is required for lightgbm.dask') raise LightGBMError("dask is required for lightgbm.dask")
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') raise LightGBMError("dask, pandas and scikit-learn are required for lightgbm.dask")
params = self.get_params(True) # type: ignore[attr-defined] params = self.get_params(True) # type: ignore[attr-defined]
params.pop("client", None) params.pop("client", None)
...@@ -1120,7 +1081,7 @@ class _DaskLGBMModel: ...@@ -1120,7 +1081,7 @@ class _DaskLGBMModel:
eval_group=eval_group, eval_group=eval_group,
eval_metric=eval_metric, eval_metric=eval_metric,
eval_at=eval_at, eval_at=eval_at,
**kwargs **kwargs,
) )
self.set_params(**model.get_params()) # type: ignore[attr-defined] self.set_params(**model.get_params()) # type: ignore[attr-defined]
...@@ -1137,7 +1098,10 @@ class _DaskLGBMModel: ...@@ -1137,7 +1098,10 @@ class _DaskLGBMModel:
return model return model
@staticmethod @staticmethod
def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None: def _lgb_dask_copy_extra_params(
source: Union["_DaskLGBMModel", LGBMModel],
dest: Union["_DaskLGBMModel", LGBMModel],
) -> None:
params = source.get_params() # type: ignore[union-attr] params = source.get_params() # type: ignore[union-attr]
attributes = source.__dict__ attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys()) extra_param_names = set(attributes.keys()).difference(params.keys())
...@@ -1150,7 +1114,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1150,7 +1114,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
def __init__( def __init__(
self, self,
boosting_type: str = 'gbdt', boosting_type: str = "gbdt",
num_leaves: int = 31, num_leaves: int = 31,
max_depth: int = -1, max_depth: int = -1,
learning_rate: float = 0.1, learning_rate: float = 0.1,
...@@ -1158,19 +1122,19 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1158,19 +1122,19 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.0,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
min_child_samples: int = 20, min_child_samples: int = 20,
subsample: float = 1., subsample: float = 1.0,
subsample_freq: int = 0, subsample_freq: int = 0,
colsample_bytree: float = 1., colsample_bytree: float = 1.0,
reg_alpha: float = 0., reg_alpha: float = 0.0,
reg_lambda: float = 0., reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = "split",
client: Optional[Client] = None, client: Optional[Client] = None,
**kwargs: Any **kwargs: Any,
): ):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__.""" """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self.client = client self.client = client
...@@ -1194,11 +1158,11 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1194,11 +1158,11 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
random_state=random_state, random_state=random_state,
n_jobs=n_jobs, n_jobs=n_jobs,
importance_type=importance_type, importance_type=importance_type,
**kwargs **kwargs,
) )
_base_doc = LGBMClassifier.__init__.__doc__ _base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore
__init__.__doc__ = f""" __init__.__doc__ = f"""
{_before_kwargs}client : dask.distributed.Client or None, optional (default=None) {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
{' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled. {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
...@@ -1220,7 +1184,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1220,7 +1184,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
eval_class_weight: Optional[List[Union[dict, str]]] = None, eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None, eval_init_score: Optional[List[_DaskCollection]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
**kwargs: Any **kwargs: Any,
) -> "DaskLGBMClassifier": ) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
self._lgb_dask_fit( self._lgb_dask_fit(
...@@ -1235,7 +1199,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1235,7 +1199,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
eval_class_weight=eval_class_weight, eval_class_weight=eval_class_weight,
eval_init_score=eval_init_score, eval_init_score=eval_init_score,
eval_metric=eval_metric, eval_metric=eval_metric,
**kwargs **kwargs,
) )
return self return self
...@@ -1247,15 +1211,13 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1247,15 +1211,13 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
group_shape="Dask Array or Dask Series or None, optional (default=None)", group_shape="Dask Array or Dask Series or None, optional (default=None)",
eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)", eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
eval_init_score_shape="list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)", eval_init_score_shape="list of Dask Array, Dask Series or Dask DataFrame (for multi-class task), or None, optional (default=None)",
eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)" eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
) )
# DaskLGBMClassifier does not support group, eval_group. # DaskLGBMClassifier does not support group, eval_group.
_base_doc = (_base_doc[:_base_doc.find('group :')] _base_doc = _base_doc[: _base_doc.find("group :")] + _base_doc[_base_doc.find("eval_set :") :]
+ _base_doc[_base_doc.find('eval_set :'):])
_base_doc = (_base_doc[:_base_doc.find('eval_group :')] _base_doc = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]
+ _base_doc[_base_doc.find('eval_metric :'):])
# DaskLGBMClassifier support for callbacks and init_model is not tested # DaskLGBMClassifier support for callbacks and init_model is not tested
fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
...@@ -1278,7 +1240,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1278,7 +1240,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
) -> dask_Array: ) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict( return _predict(
...@@ -1292,7 +1254,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1292,7 +1254,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
validate_features=validate_features, validate_features=validate_features,
**kwargs **kwargs,
) )
predict.__doc__ = _lgbmmodel_doc_predict.format( predict.__doc__ = _lgbmmodel_doc_predict.format(
...@@ -1301,7 +1263,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1301,7 +1263,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
output_name="predicted_result", output_name="predicted_result",
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]", predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]" X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
) )
def predict_proba( def predict_proba(
...@@ -1313,7 +1275,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1313,7 +1275,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
) -> dask_Array: ) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict( return _predict(
...@@ -1327,7 +1289,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1327,7 +1289,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
validate_features=validate_features, validate_features=validate_features,
**kwargs **kwargs,
) )
predict_proba.__doc__ = _lgbmmodel_doc_predict.format( predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
...@@ -1336,7 +1298,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -1336,7 +1298,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
output_name="predicted_probability", output_name="predicted_probability",
predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]", predicted_result_shape="Dask Array of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]" X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or (if multi-class and using sparse inputs) a list of ``n_classes`` Dask Arrays of shape = [n_samples, n_features + 1]",
) )
def to_local(self) -> LGBMClassifier: def to_local(self) -> LGBMClassifier:
...@@ -1355,7 +1317,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1355,7 +1317,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
def __init__( def __init__(
self, self,
boosting_type: str = 'gbdt', boosting_type: str = "gbdt",
num_leaves: int = 31, num_leaves: int = 31,
max_depth: int = -1, max_depth: int = -1,
learning_rate: float = 0.1, learning_rate: float = 0.1,
...@@ -1363,19 +1325,19 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1363,19 +1325,19 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.0,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
min_child_samples: int = 20, min_child_samples: int = 20,
subsample: float = 1., subsample: float = 1.0,
subsample_freq: int = 0, subsample_freq: int = 0,
colsample_bytree: float = 1., colsample_bytree: float = 1.0,
reg_alpha: float = 0., reg_alpha: float = 0.0,
reg_lambda: float = 0., reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = "split",
client: Optional[Client] = None, client: Optional[Client] = None,
**kwargs: Any **kwargs: Any,
): ):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__.""" """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self.client = client self.client = client
...@@ -1399,11 +1361,11 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1399,11 +1361,11 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
random_state=random_state, random_state=random_state,
n_jobs=n_jobs, n_jobs=n_jobs,
importance_type=importance_type, importance_type=importance_type,
**kwargs **kwargs,
) )
_base_doc = LGBMRegressor.__init__.__doc__ _base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore
__init__.__doc__ = f""" __init__.__doc__ = f"""
{_before_kwargs}client : dask.distributed.Client or None, optional (default=None) {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
{' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled. {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
...@@ -1424,7 +1386,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1424,7 +1386,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
eval_sample_weight: Optional[List[_DaskVectorLike]] = None, eval_sample_weight: Optional[List[_DaskVectorLike]] = None,
eval_init_score: Optional[List[_DaskVectorLike]] = None, eval_init_score: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
**kwargs: Any **kwargs: Any,
) -> "DaskLGBMRegressor": ) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
self._lgb_dask_fit( self._lgb_dask_fit(
...@@ -1438,7 +1400,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1438,7 +1400,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
eval_sample_weight=eval_sample_weight, eval_sample_weight=eval_sample_weight,
eval_init_score=eval_init_score, eval_init_score=eval_init_score,
eval_metric=eval_metric, eval_metric=eval_metric,
**kwargs **kwargs,
) )
return self return self
...@@ -1450,18 +1412,15 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1450,18 +1412,15 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
group_shape="Dask Array or Dask Series or None, optional (default=None)", group_shape="Dask Array or Dask Series or None, optional (default=None)",
eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)", eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)", eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)" eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
) )
# DaskLGBMRegressor does not support group, eval_class_weight, eval_group. # DaskLGBMRegressor does not support group, eval_class_weight, eval_group.
_base_doc = (_base_doc[:_base_doc.find('group :')] _base_doc = _base_doc[: _base_doc.find("group :")] + _base_doc[_base_doc.find("eval_set :") :]
+ _base_doc[_base_doc.find('eval_set :'):])
_base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')] _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]
+ _base_doc[_base_doc.find('eval_init_score :'):])
_base_doc = (_base_doc[:_base_doc.find('eval_group :')] _base_doc = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]
+ _base_doc[_base_doc.find('eval_metric :'):])
# DaskLGBMRegressor support for callbacks and init_model is not tested # DaskLGBMRegressor support for callbacks and init_model is not tested
fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
...@@ -1484,7 +1443,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1484,7 +1443,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
) -> dask_Array: ) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict( return _predict(
...@@ -1497,7 +1456,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1497,7 +1456,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
validate_features=validate_features, validate_features=validate_features,
**kwargs **kwargs,
) )
predict.__doc__ = _lgbmmodel_doc_predict.format( predict.__doc__ = _lgbmmodel_doc_predict.format(
...@@ -1506,7 +1465,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -1506,7 +1465,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
output_name="predicted_result", output_name="predicted_result",
predicted_result_shape="Dask Array of shape = [n_samples]", predicted_result_shape="Dask Array of shape = [n_samples]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]" X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]",
) )
def to_local(self) -> LGBMRegressor: def to_local(self) -> LGBMRegressor:
...@@ -1525,7 +1484,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1525,7 +1484,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
def __init__( def __init__(
self, self,
boosting_type: str = 'gbdt', boosting_type: str = "gbdt",
num_leaves: int = 31, num_leaves: int = 31,
max_depth: int = -1, max_depth: int = -1,
learning_rate: float = 0.1, learning_rate: float = 0.1,
...@@ -1533,19 +1492,19 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1533,19 +1492,19 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None, class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.0,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
min_child_samples: int = 20, min_child_samples: int = 20,
subsample: float = 1., subsample: float = 1.0,
subsample_freq: int = 0, subsample_freq: int = 0,
colsample_bytree: float = 1., colsample_bytree: float = 1.0,
reg_alpha: float = 0., reg_alpha: float = 0.0,
reg_lambda: float = 0., reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = "split",
client: Optional[Client] = None, client: Optional[Client] = None,
**kwargs: Any **kwargs: Any,
): ):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__.""" """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self.client = client self.client = client
...@@ -1569,11 +1528,11 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1569,11 +1528,11 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
random_state=random_state, random_state=random_state,
n_jobs=n_jobs, n_jobs=n_jobs,
importance_type=importance_type, importance_type=importance_type,
**kwargs **kwargs,
) )
_base_doc = LGBMRanker.__init__.__doc__ _base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition("**kwargs") # type: ignore
__init__.__doc__ = f""" __init__.__doc__ = f"""
{_before_kwargs}client : dask.distributed.Client or None, optional (default=None) {_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
{' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled. {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
...@@ -1597,7 +1556,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1597,7 +1556,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
eval_group: Optional[List[_DaskVectorLike]] = None, eval_group: Optional[List[_DaskVectorLike]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5), eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5),
**kwargs: Any **kwargs: Any,
) -> "DaskLGBMRanker": ) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit.""" """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
self._lgb_dask_fit( self._lgb_dask_fit(
...@@ -1614,7 +1573,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1614,7 +1573,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
eval_group=eval_group, eval_group=eval_group,
eval_metric=eval_metric, eval_metric=eval_metric,
eval_at=eval_at, eval_at=eval_at,
**kwargs **kwargs,
) )
return self return self
...@@ -1626,17 +1585,18 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1626,17 +1585,18 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
group_shape="Dask Array or Dask Series or None, optional (default=None)", group_shape="Dask Array or Dask Series or None, optional (default=None)",
eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)", eval_sample_weight_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)", eval_init_score_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)" eval_group_shape="list of Dask Array or Dask Series, or None, optional (default=None)",
) )
# DaskLGBMRanker does not support eval_class_weight or early stopping # DaskLGBMRanker does not support eval_class_weight or early stopping
_base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')] _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]
+ _base_doc[_base_doc.find('eval_init_score :'):])
_base_doc = (_base_doc[:_base_doc.find('feature_name :')] _base_doc = (
+ "eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))\n" _base_doc[: _base_doc.find("feature_name :")]
+ f"{' ':8}The evaluation positions of the specified metric.\n" + "eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))\n"
+ f"{' ':4}{_base_doc[_base_doc.find('feature_name :'):]}") + f"{' ':8}The evaluation positions of the specified metric.\n"
+ f"{' ':4}{_base_doc[_base_doc.find('feature_name :'):]}"
)
# DaskLGBMRanker support for callbacks and init_model is not tested # DaskLGBMRanker support for callbacks and init_model is not tested
fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs fit.__doc__ = f"""{_base_doc[:_base_doc.find('callbacks :')]}**kwargs
...@@ -1659,7 +1619,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1659,7 +1619,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
) -> dask_Array: ) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict.""" """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict( return _predict(
...@@ -1672,7 +1632,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1672,7 +1632,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
validate_features=validate_features, validate_features=validate_features,
**kwargs **kwargs,
) )
predict.__doc__ = _lgbmmodel_doc_predict.format( predict.__doc__ = _lgbmmodel_doc_predict.format(
...@@ -1681,7 +1641,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -1681,7 +1641,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
output_name="predicted_result", output_name="predicted_result",
predicted_result_shape="Dask Array of shape = [n_samples]", predicted_result_shape="Dask Array of shape = [n_samples]",
X_leaves_shape="Dask Array of shape = [n_samples, n_trees]", X_leaves_shape="Dask Array of shape = [n_samples, n_trees]",
X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]" X_SHAP_values_shape="Dask Array of shape = [n_samples, n_features + 1]",
) )
def to_local(self) -> LGBMRanker: def to_local(self) -> LGBMRanker:
......
...@@ -28,9 +28,9 @@ from .basic import ( ...@@ -28,9 +28,9 @@ from .basic import (
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
__all__ = [ __all__ = [
'cv', "cv",
'CVBooster', "CVBooster",
'train', "train",
] ]
...@@ -41,13 +41,13 @@ _LGBM_CustomMetricFunction = Union[ ...@@ -41,13 +41,13 @@ _LGBM_CustomMetricFunction = Union[
], ],
Callable[ Callable[
[np.ndarray, Dataset], [np.ndarray, Dataset],
List[_LGBM_EvalFunctionResultType] List[_LGBM_EvalFunctionResultType],
], ],
] ]
_LGBM_PreprocFunction = Callable[ _LGBM_PreprocFunction = Callable[
[Dataset, Dataset, Dict[str, Any]], [Dataset, Dataset, Dict[str, Any]],
Tuple[Dataset, Dataset, Dict[str, Any]] Tuple[Dataset, Dataset, Dict[str, Any]],
] ]
...@@ -59,10 +59,10 @@ def train( ...@@ -59,10 +59,10 @@ def train(
valid_names: Optional[List[str]] = None, valid_names: Optional[List[str]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None, feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None, init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
keep_training_booster: bool = False, keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None callbacks: Optional[List[Callable]] = None,
) -> Booster: ) -> Booster:
"""Perform the training with given parameters. """Perform the training with given parameters.
...@@ -169,14 +169,14 @@ def train( ...@@ -169,14 +169,14 @@ def train(
# create predictor first # create predictor first
params = copy.deepcopy(params) params = copy.deepcopy(params)
params = _choose_param_value( params = _choose_param_value(
main_param_name='objective', main_param_name="objective",
params=params, params=params,
default_value=None default_value=None,
) )
fobj: Optional[_LGBM_CustomObjectiveFunction] = None fobj: Optional[_LGBM_CustomObjectiveFunction] = None
if callable(params["objective"]): if callable(params["objective"]):
fobj = params["objective"] fobj = params["objective"]
params["objective"] = 'none' params["objective"] = "none"
for alias in _ConfigAliases.get("num_iterations"): for alias in _ConfigAliases.get("num_iterations"):
if alias in params: if alias in params:
num_boost_round = params.pop(alias) num_boost_round = params.pop(alias)
...@@ -186,33 +186,26 @@ def train( ...@@ -186,33 +186,26 @@ def train(
params = _choose_param_value( params = _choose_param_value(
main_param_name="early_stopping_round", main_param_name="early_stopping_round",
params=params, params=params,
default_value=None default_value=None,
) )
if params["early_stopping_round"] is None: if params["early_stopping_round"] is None:
params.pop("early_stopping_round") params.pop("early_stopping_round")
first_metric_only = params.get('first_metric_only', False) first_metric_only = params.get("first_metric_only", False)
predictor: Optional[_InnerPredictor] = None predictor: Optional[_InnerPredictor] = None
if isinstance(init_model, (str, Path)): if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor.from_model_file( predictor = _InnerPredictor.from_model_file(model_file=init_model, pred_parameter=params)
model_file=init_model,
pred_parameter=params
)
elif isinstance(init_model, Booster): elif isinstance(init_model, Booster):
predictor = _InnerPredictor.from_booster( predictor = _InnerPredictor.from_booster(booster=init_model, pred_parameter=dict(init_model.params, **params))
booster=init_model,
pred_parameter=dict(init_model.params, **params)
)
if predictor is not None: if predictor is not None:
init_iteration = predictor.current_iteration() init_iteration = predictor.current_iteration()
else: else:
init_iteration = 0 init_iteration = 0
train_set._update_params(params) \ train_set._update_params(params)._set_predictor(predictor).set_feature_name(feature_name).set_categorical_feature(
._set_predictor(predictor) \ categorical_feature
.set_feature_name(feature_name) \ )
.set_categorical_feature(categorical_feature)
is_valid_contain_train = False is_valid_contain_train = False
train_data_name = "training" train_data_name = "training"
...@@ -234,13 +227,13 @@ def train( ...@@ -234,13 +227,13 @@ def train(
if valid_names is not None and len(valid_names) > i: if valid_names is not None and len(valid_names) > i:
name_valid_sets.append(valid_names[i]) name_valid_sets.append(valid_names[i])
else: else:
name_valid_sets.append(f'valid_{i}') name_valid_sets.append(f"valid_{i}")
# process callbacks # process callbacks
if callbacks is None: if callbacks is None:
callbacks_set = set() callbacks_set = set()
else: else:
for i, cb in enumerate(callbacks): for i, cb in enumerate(callbacks):
cb.__dict__.setdefault('order', i - len(callbacks)) cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks) callbacks_set = set(callbacks)
if "early_stopping_round" in params: if "early_stopping_round" in params:
...@@ -251,15 +244,16 @@ def train( ...@@ -251,15 +244,16 @@ def train(
verbose=_choose_param_value( verbose=_choose_param_value(
main_param_name="verbosity", main_param_name="verbosity",
params=params, params=params,
default_value=1 default_value=1,
).pop("verbosity") > 0 ).pop("verbosity")
> 0,
) )
) )
callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, 'before_iteration', False)} callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, "before_iteration", False)}
callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set
callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter('order')) callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter("order"))
callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter('order')) callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter("order"))
# construct booster # construct booster
try: try:
...@@ -277,12 +271,16 @@ def train( ...@@ -277,12 +271,16 @@ def train(
# start training # start training
for i in range(init_iteration, init_iteration + num_boost_round): for i in range(init_iteration, init_iteration + num_boost_round):
for cb in callbacks_before_iter: for cb in callbacks_before_iter:
cb(callback.CallbackEnv(model=booster, cb(
params=params, callback.CallbackEnv(
iteration=i, model=booster,
begin_iteration=init_iteration, params=params,
end_iteration=init_iteration + num_boost_round, iteration=i,
evaluation_result_list=None)) begin_iteration=init_iteration,
end_iteration=init_iteration + num_boost_round,
evaluation_result_list=None,
)
)
booster.update(fobj=fobj) booster.update(fobj=fobj)
...@@ -294,12 +292,16 @@ def train( ...@@ -294,12 +292,16 @@ def train(
evaluation_result_list.extend(booster.eval_valid(feval)) evaluation_result_list.extend(booster.eval_valid(feval))
try: try:
for cb in callbacks_after_iter: for cb in callbacks_after_iter:
cb(callback.CallbackEnv(model=booster, cb(
params=params, callback.CallbackEnv(
iteration=i, model=booster,
begin_iteration=init_iteration, params=params,
end_iteration=init_iteration + num_boost_round, iteration=i,
evaluation_result_list=evaluation_result_list)) begin_iteration=init_iteration,
end_iteration=init_iteration + num_boost_round,
evaluation_result_list=evaluation_result_list,
)
)
except callback.EarlyStopException as earlyStopException: except callback.EarlyStopException as earlyStopException:
booster.best_iteration = earlyStopException.best_iteration + 1 booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score evaluation_result_list = earlyStopException.best_score
...@@ -334,7 +336,7 @@ class CVBooster: ...@@ -334,7 +336,7 @@ class CVBooster:
def __init__( def __init__(
self, self,
model_file: Optional[Union[str, Path]] = None model_file: Optional[Union[str, Path]] = None,
): ):
"""Initialize the CVBooster. """Initialize the CVBooster.
...@@ -361,18 +363,23 @@ class CVBooster: ...@@ -361,18 +363,23 @@ class CVBooster:
"""Serialize CVBooster to dict.""" """Serialize CVBooster to dict."""
models_str = [] models_str = []
for booster in self.boosters: for booster in self.boosters:
models_str.append(booster.model_to_string(num_iteration=num_iteration, start_iteration=start_iteration, models_str.append(
importance_type=importance_type)) booster.model_to_string(
num_iteration=num_iteration, start_iteration=start_iteration, importance_type=importance_type
)
)
return {"boosters": models_str, "best_iteration": self.best_iteration} return {"boosters": models_str, "best_iteration": self.best_iteration}
def __getattr__(self, name: str) -> Callable[[Any, Any], List[Any]]: def __getattr__(self, name: str) -> Callable[[Any, Any], List[Any]]:
"""Redirect methods call of CVBooster.""" """Redirect methods call of CVBooster."""
def handler_function(*args: Any, **kwargs: Any) -> List[Any]: def handler_function(*args: Any, **kwargs: Any) -> List[Any]:
"""Call methods with each booster, and concatenate their results.""" """Call methods with each booster, and concatenate their results."""
ret = [] ret = []
for booster in self.boosters: for booster in self.boosters:
ret.append(getattr(booster, name)(*args, **kwargs)) ret.append(getattr(booster, name)(*args, **kwargs))
return ret return ret
return handler_function return handler_function
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
...@@ -401,7 +408,7 @@ class CVBooster: ...@@ -401,7 +408,7 @@ class CVBooster:
self, self,
num_iteration: Optional[int] = None, num_iteration: Optional[int] = None,
start_iteration: int = 0, start_iteration: int = 0,
importance_type: str = 'split' importance_type: str = "split",
) -> str: ) -> str:
"""Save CVBooster to JSON string. """Save CVBooster to JSON string.
...@@ -430,7 +437,7 @@ class CVBooster: ...@@ -430,7 +437,7 @@ class CVBooster:
filename: Union[str, Path], filename: Union[str, Path],
num_iteration: Optional[int] = None, num_iteration: Optional[int] = None,
start_iteration: int = 0, start_iteration: int = 0,
importance_type: str = 'split' importance_type: str = "split",
) -> "CVBooster": ) -> "CVBooster":
"""Save CVBooster to a file as JSON text. """Save CVBooster to a file as JSON text.
...@@ -469,16 +476,18 @@ def _make_n_folds( ...@@ -469,16 +476,18 @@ def _make_n_folds(
fpreproc: Optional[_LGBM_PreprocFunction], fpreproc: Optional[_LGBM_PreprocFunction],
stratified: bool, stratified: bool,
shuffle: bool, shuffle: bool,
eval_train_metric: bool eval_train_metric: bool,
) -> CVBooster: ) -> CVBooster:
"""Make a n-fold list of Booster from random indices.""" """Make a n-fold list of Booster from random indices."""
full_data = full_data.construct() full_data = full_data.construct()
num_data = full_data.num_data() num_data = full_data.num_data()
if folds is not None: if folds is not None:
if not hasattr(folds, '__iter__') and not hasattr(folds, 'split'): if not hasattr(folds, "__iter__") and not hasattr(folds, "split"):
raise AttributeError("folds should be a generator or iterator of (train_idx, test_idx) tuples " raise AttributeError(
"or scikit-learn splitter object with split method") "folds should be a generator or iterator of (train_idx, test_idx) tuples "
if hasattr(folds, 'split'): "or scikit-learn splitter object with split method"
)
if hasattr(folds, "split"):
group_info = full_data.get_group() group_info = full_data.get_group()
if group_info is not None: if group_info is not None:
group_info = np.array(group_info, dtype=np.int32, copy=False) group_info = np.array(group_info, dtype=np.int32, copy=False)
...@@ -487,11 +496,13 @@ def _make_n_folds( ...@@ -487,11 +496,13 @@ def _make_n_folds(
flatted_group = np.zeros(num_data, dtype=np.int32) flatted_group = np.zeros(num_data, dtype=np.int32)
folds = folds.split(X=np.empty(num_data), y=full_data.get_label(), groups=flatted_group) folds = folds.split(X=np.empty(num_data), y=full_data.get_label(), groups=flatted_group)
else: else:
if any(params.get(obj_alias, "") in {"lambdarank", "rank_xendcg", "xendcg", if any(
"xe_ndcg", "xe_ndcg_mart", "xendcg_mart"} params.get(obj_alias, "")
for obj_alias in _ConfigAliases.get("objective")): in {"lambdarank", "rank_xendcg", "xendcg", "xe_ndcg", "xe_ndcg_mart", "xendcg_mart"}
for obj_alias in _ConfigAliases.get("objective")
):
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise LightGBMError('scikit-learn is required for ranking cv') raise LightGBMError("scikit-learn is required for ranking cv")
# ranking task, split according to groups # ranking task, split according to groups
group_info = np.array(full_data.get_group(), dtype=np.int32, copy=False) group_info = np.array(full_data.get_group(), dtype=np.int32, copy=False)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info) flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
...@@ -499,7 +510,7 @@ def _make_n_folds( ...@@ -499,7 +510,7 @@ def _make_n_folds(
folds = group_kfold.split(X=np.empty(num_data), groups=flatted_group) folds = group_kfold.split(X=np.empty(num_data), groups=flatted_group)
elif stratified: elif stratified:
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise LightGBMError('scikit-learn is required for stratified cv') raise LightGBMError("scikit-learn is required for stratified cv")
skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed) skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.empty(num_data), y=full_data.get_label()) folds = skf.split(X=np.empty(num_data), y=full_data.get_label())
else: else:
...@@ -508,7 +519,7 @@ def _make_n_folds( ...@@ -508,7 +519,7 @@ def _make_n_folds(
else: else:
randidx = np.arange(num_data) randidx = np.arange(num_data)
kstep = int(num_data / nfold) kstep = int(num_data / nfold)
test_id = [randidx[i: i + kstep] for i in range(0, num_data, kstep)] test_id = [randidx[i : i + kstep] for i in range(0, num_data, kstep)]
train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)] train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
folds = zip(train_id, test_id) folds = zip(train_id, test_id)
...@@ -523,14 +534,14 @@ def _make_n_folds( ...@@ -523,14 +534,14 @@ def _make_n_folds(
tparam = params tparam = params
booster_for_fold = Booster(tparam, train_set) booster_for_fold = Booster(tparam, train_set)
if eval_train_metric: if eval_train_metric:
booster_for_fold.add_valid(train_set, 'train') booster_for_fold.add_valid(train_set, "train")
booster_for_fold.add_valid(valid_set, 'valid') booster_for_fold.add_valid(valid_set, "valid")
ret.boosters.append(booster_for_fold) ret.boosters.append(booster_for_fold)
return ret return ret
def _agg_cv_result( def _agg_cv_result(
raw_results: List[List[_LGBM_BoosterEvalMethodResultType]] raw_results: List[List[_LGBM_BoosterEvalMethodResultType]],
) -> List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]: ) -> List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]:
"""Aggregate cross-validation results.""" """Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = OrderedDict() cvmap: Dict[str, List[float]] = OrderedDict()
...@@ -541,7 +552,7 @@ def _agg_cv_result( ...@@ -541,7 +552,7 @@ def _agg_cv_result(
metric_type[key] = one_line[3] metric_type[key] = one_line[3]
cvmap.setdefault(key, []) cvmap.setdefault(key, [])
cvmap[key].append(one_line[2]) cvmap[key].append(one_line[2])
return [('cv_agg', k, float(np.mean(v)), metric_type[k], float(np.std(v))) for k, v in cvmap.items()] return [("cv_agg", k, float(np.mean(v)), metric_type[k], float(np.std(v))) for k, v in cvmap.items()]
def cv( def cv(
...@@ -555,13 +566,13 @@ def cv( ...@@ -555,13 +566,13 @@ def cv(
metrics: Optional[Union[str, List[str]]] = None, metrics: Optional[Union[str, List[str]]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None, feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None, init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
fpreproc: Optional[_LGBM_PreprocFunction] = None, fpreproc: Optional[_LGBM_PreprocFunction] = None,
seed: int = 0, seed: int = 0,
callbacks: Optional[List[Callable]] = None, callbacks: Optional[List[Callable]] = None,
eval_train_metric: bool = False, eval_train_metric: bool = False,
return_cvbooster: bool = False return_cvbooster: bool = False,
) -> Dict[str, Union[List[float], CVBooster]]: ) -> Dict[str, Union[List[float], CVBooster]]:
"""Perform the cross-validation with given parameters. """Perform the cross-validation with given parameters.
...@@ -683,14 +694,14 @@ def cv( ...@@ -683,14 +694,14 @@ def cv(
params = copy.deepcopy(params) params = copy.deepcopy(params)
params = _choose_param_value( params = _choose_param_value(
main_param_name='objective', main_param_name="objective",
params=params, params=params,
default_value=None default_value=None,
) )
fobj: Optional[_LGBM_CustomObjectiveFunction] = None fobj: Optional[_LGBM_CustomObjectiveFunction] = None
if callable(params["objective"]): if callable(params["objective"]):
fobj = params["objective"] fobj = params["objective"]
params["objective"] = 'none' params["objective"] = "none"
for alias in _ConfigAliases.get("num_iterations"): for alias in _ConfigAliases.get("num_iterations"):
if alias in params: if alias in params:
_log_warning(f"Found '{alias}' in params. Will use it instead of 'num_boost_round' argument") _log_warning(f"Found '{alias}' in params. Will use it instead of 'num_boost_round' argument")
...@@ -700,21 +711,21 @@ def cv( ...@@ -700,21 +711,21 @@ def cv(
params = _choose_param_value( params = _choose_param_value(
main_param_name="early_stopping_round", main_param_name="early_stopping_round",
params=params, params=params,
default_value=None default_value=None,
) )
if params["early_stopping_round"] is None: if params["early_stopping_round"] is None:
params.pop("early_stopping_round") params.pop("early_stopping_round")
first_metric_only = params.get('first_metric_only', False) first_metric_only = params.get("first_metric_only", False)
if isinstance(init_model, (str, Path)): if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor.from_model_file( predictor = _InnerPredictor.from_model_file(
model_file=init_model, model_file=init_model,
pred_parameter=params pred_parameter=params,
) )
elif isinstance(init_model, Booster): elif isinstance(init_model, Booster):
predictor = _InnerPredictor.from_booster( predictor = _InnerPredictor.from_booster(
booster=init_model, booster=init_model,
pred_parameter=dict(init_model.params, **params) pred_parameter=dict(init_model.params, **params),
) )
else: else:
predictor = None predictor = None
...@@ -722,25 +733,31 @@ def cv( ...@@ -722,25 +733,31 @@ def cv(
if metrics is not None: if metrics is not None:
for metric_alias in _ConfigAliases.get("metric"): for metric_alias in _ConfigAliases.get("metric"):
params.pop(metric_alias, None) params.pop(metric_alias, None)
params['metric'] = metrics params["metric"] = metrics
train_set._update_params(params) \ train_set._update_params(params)._set_predictor(predictor).set_feature_name(feature_name).set_categorical_feature(
._set_predictor(predictor) \ categorical_feature
.set_feature_name(feature_name) \ )
.set_categorical_feature(categorical_feature)
results = defaultdict(list) results = defaultdict(list)
cvfolds = _make_n_folds(full_data=train_set, folds=folds, nfold=nfold, cvfolds = _make_n_folds(
params=params, seed=seed, fpreproc=fpreproc, full_data=train_set,
stratified=stratified, shuffle=shuffle, folds=folds,
eval_train_metric=eval_train_metric) nfold=nfold,
params=params,
seed=seed,
fpreproc=fpreproc,
stratified=stratified,
shuffle=shuffle,
eval_train_metric=eval_train_metric,
)
# setup callbacks # setup callbacks
if callbacks is None: if callbacks is None:
callbacks_set = set() callbacks_set = set()
else: else:
for i, cb in enumerate(callbacks): for i, cb in enumerate(callbacks):
cb.__dict__.setdefault('order', i - len(callbacks)) cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks) callbacks_set = set(callbacks)
if "early_stopping_round" in params: if "early_stopping_round" in params:
...@@ -751,46 +768,55 @@ def cv( ...@@ -751,46 +768,55 @@ def cv(
verbose=_choose_param_value( verbose=_choose_param_value(
main_param_name="verbosity", main_param_name="verbosity",
params=params, params=params,
default_value=1 default_value=1,
).pop("verbosity") > 0 ).pop("verbosity")
> 0,
) )
) )
callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, 'before_iteration', False)} callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, "before_iteration", False)}
callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set
callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter('order')) callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter("order"))
callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter('order')) callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter("order"))
for i in range(num_boost_round): for i in range(num_boost_round):
for cb in callbacks_before_iter: for cb in callbacks_before_iter:
cb(callback.CallbackEnv(model=cvfolds, cb(
params=params, callback.CallbackEnv(
iteration=i, model=cvfolds,
begin_iteration=0, params=params,
end_iteration=num_boost_round, iteration=i,
evaluation_result_list=None)) begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None,
)
)
cvfolds.update(fobj=fobj) # type: ignore[call-arg] cvfolds.update(fobj=fobj) # type: ignore[call-arg]
res = _agg_cv_result(cvfolds.eval_valid(feval)) # type: ignore[call-arg] res = _agg_cv_result(cvfolds.eval_valid(feval)) # type: ignore[call-arg]
for _, key, mean, _, std in res: for _, key, mean, _, std in res:
results[f'{key}-mean'].append(mean) results[f"{key}-mean"].append(mean)
results[f'{key}-stdv'].append(std) results[f"{key}-stdv"].append(std)
try: try:
for cb in callbacks_after_iter: for cb in callbacks_after_iter:
cb(callback.CallbackEnv(model=cvfolds, cb(
params=params, callback.CallbackEnv(
iteration=i, model=cvfolds,
begin_iteration=0, params=params,
end_iteration=num_boost_round, iteration=i,
evaluation_result_list=res)) begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=res,
)
)
except callback.EarlyStopException as earlyStopException: except callback.EarlyStopException as earlyStopException:
cvfolds.best_iteration = earlyStopException.best_iteration + 1 cvfolds.best_iteration = earlyStopException.best_iteration + 1
for bst in cvfolds.boosters: for bst in cvfolds.boosters:
bst.best_iteration = cvfolds.best_iteration bst.best_iteration = cvfolds.best_iteration
for k in results: for k in results:
results[k] = results[k][:cvfolds.best_iteration] results[k] = results[k][: cvfolds.best_iteration]
break break
if return_cvbooster: if return_cvbooster:
results['cvbooster'] = cvfolds # type: ignore[assignment] results["cvbooster"] = cvfolds # type: ignore[assignment]
return dict(results) return dict(results)
...@@ -16,17 +16,19 @@ def find_lib_path() -> List[str]: ...@@ -16,17 +16,19 @@ def find_lib_path() -> List[str]:
List of all found library paths to LightGBM. List of all found library paths to LightGBM.
""" """
curr_path = Path(__file__).absolute() curr_path = Path(__file__).absolute()
dll_path = [curr_path.parents[1], dll_path = [
curr_path.parents[0] / 'bin', curr_path.parents[1],
curr_path.parents[0] / 'lib'] curr_path.parents[0] / "bin",
if system() in ('Windows', 'Microsoft'): curr_path.parents[0] / "lib",
dll_path.append(curr_path.parents[1] / 'Release') ]
dll_path.append(curr_path.parents[1] / 'windows' / 'x64' / 'DLL') if system() in ("Windows", "Microsoft"):
dll_path = [p / 'lib_lightgbm.dll' for p in dll_path] dll_path.append(curr_path.parents[1] / "Release")
dll_path.append(curr_path.parents[1] / "windows" / "x64" / "DLL")
dll_path = [p / "lib_lightgbm.dll" for p in dll_path]
else: else:
dll_path = [p / 'lib_lightgbm.so' for p in dll_path] dll_path = [p / "lib_lightgbm.so" for p in dll_path]
lib_path = [str(p) for p in dll_path if p.is_file()] lib_path = [str(p) for p in dll_path if p.is_file()]
if not lib_path: if not lib_path:
dll_path_joined = '\n'.join(map(str, dll_path)) dll_path_joined = "\n".join(map(str, dll_path))
raise Exception(f'Cannot find lightgbm library file in following paths:\n{dll_path_joined}') raise Exception(f"Cannot find lightgbm library file in following paths:\n{dll_path_joined}")
return lib_path return lib_path
...@@ -12,11 +12,11 @@ from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame ...@@ -12,11 +12,11 @@ from .compat import GRAPHVIZ_INSTALLED, MATPLOTLIB_INSTALLED, pd_DataFrame
from .sklearn import LGBMModel from .sklearn import LGBMModel
__all__ = [ __all__ = [
'create_tree_digraph', "create_tree_digraph",
'plot_importance', "plot_importance",
'plot_metric', "plot_metric",
'plot_split_value_histogram', "plot_split_value_histogram",
'plot_tree', "plot_tree",
] ]
...@@ -27,9 +27,7 @@ def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None: ...@@ -27,9 +27,7 @@ def _check_not_tuple_of_2_elements(obj: Any, obj_name: str) -> None:
def _float2str(value: float, precision: Optional[int]) -> str: def _float2str(value: float, precision: Optional[int]) -> str:
return (f"{value:.{precision}f}" return f"{value:.{precision}f}" if precision is not None and not isinstance(value, str) else str(value)
if precision is not None and not isinstance(value, str)
else str(value))
def plot_importance( def plot_importance(
...@@ -38,17 +36,17 @@ def plot_importance( ...@@ -38,17 +36,17 @@ def plot_importance(
height: float = 0.2, height: float = 0.2,
xlim: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = 'Feature importance', title: Optional[str] = "Feature importance",
xlabel: Optional[str] = 'Feature importance', xlabel: Optional[str] = "Feature importance",
ylabel: Optional[str] = 'Features', ylabel: Optional[str] = "Features",
importance_type: str = 'auto', importance_type: str = "auto",
max_num_features: Optional[int] = None, max_num_features: Optional[int] = None,
ignore_zero: bool = True, ignore_zero: bool = True,
figsize: Optional[Tuple[float, float]] = None, figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None, dpi: Optional[int] = None,
grid: bool = True, grid: bool = True,
precision: Optional[int] = 3, precision: Optional[int] = 3,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
"""Plot model's feature importances. """Plot model's feature importances.
...@@ -104,7 +102,7 @@ def plot_importance( ...@@ -104,7 +102,7 @@ def plot_importance(
if MATPLOTLIB_INSTALLED: if MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
else: else:
raise ImportError('You must install matplotlib and restart your session to plot importance.') raise ImportError("You must install matplotlib and restart your session to plot importance.")
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
if importance_type == "auto": if importance_type == "auto":
...@@ -114,7 +112,7 @@ def plot_importance( ...@@ -114,7 +112,7 @@ def plot_importance(
if importance_type == "auto": if importance_type == "auto":
importance_type = "split" importance_type = "split"
else: else:
raise TypeError('booster must be Booster or LGBMModel.') raise TypeError("booster must be Booster or LGBMModel.")
importance = booster.feature_importance(importance_type=importance_type) importance = booster.feature_importance(importance_type=importance_type)
feature_name = booster.feature_name() feature_name = booster.feature_name()
...@@ -131,28 +129,26 @@ def plot_importance( ...@@ -131,28 +129,26 @@ def plot_importance(
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
_check_not_tuple_of_2_elements(figsize, 'figsize') _check_not_tuple_of_2_elements(figsize, "figsize")
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
ylocs = np.arange(len(values)) ylocs = np.arange(len(values))
ax.barh(ylocs, values, align='center', height=height, **kwargs) ax.barh(ylocs, values, align="center", height=height, **kwargs)
for x, y in zip(values, ylocs): for x, y in zip(values, ylocs):
ax.text(x + 1, y, ax.text(x + 1, y, _float2str(x, precision) if importance_type == "gain" else x, va="center")
_float2str(x, precision) if importance_type == 'gain' else x,
va='center')
ax.set_yticks(ylocs) ax.set_yticks(ylocs)
ax.set_yticklabels(labels) ax.set_yticklabels(labels)
if xlim is not None: if xlim is not None:
_check_not_tuple_of_2_elements(xlim, 'xlim') _check_not_tuple_of_2_elements(xlim, "xlim")
else: else:
xlim = (0, max(values) * 1.1) xlim = (0, max(values) * 1.1)
ax.set_xlim(xlim) ax.set_xlim(xlim)
if ylim is not None: if ylim is not None:
_check_not_tuple_of_2_elements(ylim, 'ylim') _check_not_tuple_of_2_elements(ylim, "ylim")
else: else:
ylim = (-1, len(values)) ylim = (-1, len(values))
ax.set_ylim(ylim) ax.set_ylim(ylim)
...@@ -160,7 +156,7 @@ def plot_importance( ...@@ -160,7 +156,7 @@ def plot_importance(
if title is not None: if title is not None:
ax.set_title(title) ax.set_title(title)
if xlabel is not None: if xlabel is not None:
xlabel = xlabel.replace('@importance_type@', importance_type) xlabel = xlabel.replace("@importance_type@", importance_type)
ax.set_xlabel(xlabel) ax.set_xlabel(xlabel)
if ylabel is not None: if ylabel is not None:
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
...@@ -176,13 +172,13 @@ def plot_split_value_histogram( ...@@ -176,13 +172,13 @@ def plot_split_value_histogram(
width_coef: float = 0.8, width_coef: float = 0.8,
xlim: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = 'Split value histogram for feature with @index/name@ @feature@', title: Optional[str] = "Split value histogram for feature with @index/name@ @feature@",
xlabel: Optional[str] = 'Feature split value', xlabel: Optional[str] = "Feature split value",
ylabel: Optional[str] = 'Count', ylabel: Optional[str] = "Count",
figsize: Optional[Tuple[float, float]] = None, figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None, dpi: Optional[int] = None,
grid: bool = True, grid: bool = True,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
"""Plot split value histogram for the specified feature of the model. """Plot split value histogram for the specified feature of the model.
...@@ -238,29 +234,28 @@ def plot_split_value_histogram( ...@@ -238,29 +234,28 @@ def plot_split_value_histogram(
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator from matplotlib.ticker import MaxNLocator
else: else:
raise ImportError('You must install matplotlib and restart your session to plot split value histogram.') raise ImportError("You must install matplotlib and restart your session to plot split value histogram.")
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
booster = booster.booster_ booster = booster.booster_
elif not isinstance(booster, Booster): elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.') raise TypeError("booster must be Booster or LGBMModel.")
hist, split_bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False) hist, split_bins = booster.get_split_value_histogram(feature=feature, bins=bins, xgboost_style=False)
if np.count_nonzero(hist) == 0: if np.count_nonzero(hist) == 0:
raise ValueError('Cannot plot split value histogram, ' raise ValueError("Cannot plot split value histogram, " f"because feature {feature} was not used in splitting")
f'because feature {feature} was not used in splitting')
width = width_coef * (split_bins[1] - split_bins[0]) width = width_coef * (split_bins[1] - split_bins[0])
centred = (split_bins[:-1] + split_bins[1:]) / 2 centred = (split_bins[:-1] + split_bins[1:]) / 2
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
_check_not_tuple_of_2_elements(figsize, 'figsize') _check_not_tuple_of_2_elements(figsize, "figsize")
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
ax.bar(centred, hist, align='center', width=width, **kwargs) ax.bar(centred, hist, align="center", width=width, **kwargs)
if xlim is not None: if xlim is not None:
_check_not_tuple_of_2_elements(xlim, 'xlim') _check_not_tuple_of_2_elements(xlim, "xlim")
else: else:
range_result = split_bins[-1] - split_bins[0] range_result = split_bins[-1] - split_bins[0]
xlim = (split_bins[0] - range_result * 0.2, split_bins[-1] + range_result * 0.2) xlim = (split_bins[0] - range_result * 0.2, split_bins[-1] + range_result * 0.2)
...@@ -268,14 +263,14 @@ def plot_split_value_histogram( ...@@ -268,14 +263,14 @@ def plot_split_value_histogram(
ax.yaxis.set_major_locator(MaxNLocator(integer=True)) ax.yaxis.set_major_locator(MaxNLocator(integer=True))
if ylim is not None: if ylim is not None:
_check_not_tuple_of_2_elements(ylim, 'ylim') _check_not_tuple_of_2_elements(ylim, "ylim")
else: else:
ylim = (0, max(hist) * 1.1) ylim = (0, max(hist) * 1.1)
ax.set_ylim(ylim) ax.set_ylim(ylim)
if title is not None: if title is not None:
title = title.replace('@feature@', str(feature)) title = title.replace("@feature@", str(feature))
title = title.replace('@index/name@', ('name' if isinstance(feature, str) else 'index')) title = title.replace("@index/name@", ("name" if isinstance(feature, str) else "index"))
ax.set_title(title) ax.set_title(title)
if xlabel is not None: if xlabel is not None:
ax.set_xlabel(xlabel) ax.set_xlabel(xlabel)
...@@ -292,12 +287,12 @@ def plot_metric( ...@@ -292,12 +287,12 @@ def plot_metric(
ax=None, ax=None,
xlim: Optional[Tuple[float, float]] = None, xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None, ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = 'Metric during training', title: Optional[str] = "Metric during training",
xlabel: Optional[str] = 'Iterations', xlabel: Optional[str] = "Iterations",
ylabel: Optional[str] = '@metric@', ylabel: Optional[str] = "@metric@",
figsize: Optional[Tuple[float, float]] = None, figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None, dpi: Optional[int] = None,
grid: bool = True grid: bool = True,
) -> Any: ) -> Any:
"""Plot one metric during training. """Plot one metric during training.
...@@ -345,31 +340,33 @@ def plot_metric( ...@@ -345,31 +340,33 @@ def plot_metric(
if MATPLOTLIB_INSTALLED: if MATPLOTLIB_INSTALLED:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
else: else:
raise ImportError('You must install matplotlib and restart your session to plot metric.') raise ImportError("You must install matplotlib and restart your session to plot metric.")
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
eval_results = deepcopy(booster.evals_result_) eval_results = deepcopy(booster.evals_result_)
elif isinstance(booster, dict): elif isinstance(booster, dict):
eval_results = deepcopy(booster) eval_results = deepcopy(booster)
elif isinstance(booster, Booster): elif isinstance(booster, Booster):
raise TypeError("booster must be dict or LGBMModel. To use plot_metric with Booster type, first record the metrics using record_evaluation callback then pass that to plot_metric as argument `booster`") raise TypeError(
"booster must be dict or LGBMModel. To use plot_metric with Booster type, first record the metrics using record_evaluation callback then pass that to plot_metric as argument `booster`"
)
else: else:
raise TypeError('booster must be dict or LGBMModel.') raise TypeError("booster must be dict or LGBMModel.")
num_data = len(eval_results) num_data = len(eval_results)
if not num_data: if not num_data:
raise ValueError('eval results cannot be empty.') raise ValueError("eval results cannot be empty.")
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
_check_not_tuple_of_2_elements(figsize, 'figsize') _check_not_tuple_of_2_elements(figsize, "figsize")
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
if dataset_names is None: if dataset_names is None:
dataset_names_iter = iter(eval_results.keys()) dataset_names_iter = iter(eval_results.keys())
elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names: elif not isinstance(dataset_names, (list, tuple, set)) or not dataset_names:
raise ValueError('dataset_names should be iterable and cannot be empty') raise ValueError("dataset_names should be iterable and cannot be empty")
else: else:
dataset_names_iter = iter(dataset_names) dataset_names_iter = iter(dataset_names)
...@@ -382,7 +379,7 @@ def plot_metric( ...@@ -382,7 +379,7 @@ def plot_metric(
metric, results = metrics_for_one.popitem() metric, results = metrics_for_one.popitem()
else: else:
if metric not in metrics_for_one: if metric not in metrics_for_one:
raise KeyError('No given metric in eval results.') raise KeyError("No given metric in eval results.")
results = metrics_for_one[metric] results = metrics_for_one[metric]
num_iteration = len(results) num_iteration = len(results)
max_result = max(results) max_result = max(results)
...@@ -397,16 +394,16 @@ def plot_metric( ...@@ -397,16 +394,16 @@ def plot_metric(
min_result = min(*results, min_result) min_result = min(*results, min_result)
ax.plot(x_, results, label=name) ax.plot(x_, results, label=name)
ax.legend(loc='best') ax.legend(loc="best")
if xlim is not None: if xlim is not None:
_check_not_tuple_of_2_elements(xlim, 'xlim') _check_not_tuple_of_2_elements(xlim, "xlim")
else: else:
xlim = (0, num_iteration) xlim = (0, num_iteration)
ax.set_xlim(xlim) ax.set_xlim(xlim)
if ylim is not None: if ylim is not None:
_check_not_tuple_of_2_elements(ylim, 'ylim') _check_not_tuple_of_2_elements(ylim, "ylim")
else: else:
range_result = max_result - min_result range_result = max_result - min_result
ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2) ylim = (min_result - range_result * 0.2, max_result + range_result * 0.2)
...@@ -417,7 +414,7 @@ def plot_metric( ...@@ -417,7 +414,7 @@ def plot_metric(
if xlabel is not None: if xlabel is not None:
ax.set_xlabel(xlabel) ax.set_xlabel(xlabel)
if ylabel is not None: if ylabel is not None:
ylabel = ylabel.replace('@metric@', metric) ylabel = ylabel.replace("@metric@", metric)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.grid(grid) ax.grid(grid)
return ax return ax
...@@ -432,19 +429,20 @@ def _determine_direction_for_numeric_split( ...@@ -432,19 +429,20 @@ def _determine_direction_for_numeric_split(
missing_type = _MissingType(missing_type_str) missing_type = _MissingType(missing_type_str)
if math.isnan(fval) and missing_type != _MissingType.NAN: if math.isnan(fval) and missing_type != _MissingType.NAN:
fval = 0.0 fval = 0.0
if ((missing_type == _MissingType.ZERO and _is_zero(fval)) if (missing_type == _MissingType.ZERO and _is_zero(fval)) or (
or (missing_type == _MissingType.NAN and math.isnan(fval))): missing_type == _MissingType.NAN and math.isnan(fval)
direction = 'left' if default_left else 'right' ):
direction = "left" if default_left else "right"
else: else:
direction = 'left' if fval <= threshold else 'right' direction = "left" if fval <= threshold else "right"
return direction return direction
def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str: def _determine_direction_for_categorical_split(fval: float, thresholds: str) -> str:
if math.isnan(fval) or int(fval) < 0: if math.isnan(fval) or int(fval) < 0:
return 'right' return "right"
int_thresholds = {int(t) for t in thresholds.split('||')} int_thresholds = {int(t) for t in thresholds.split("||")}
return 'left' if int(fval) in int_thresholds else 'right' return "left" if int(fval) in int_thresholds else "right"
def _to_graphviz( def _to_graphviz(
...@@ -456,7 +454,7 @@ def _to_graphviz( ...@@ -456,7 +454,7 @@ def _to_graphviz(
constraints: Optional[List[int]], constraints: Optional[List[int]],
example_case: Optional[Union[np.ndarray, pd_DataFrame]], example_case: Optional[Union[np.ndarray, pd_DataFrame]],
max_category_values: int, max_category_values: int,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
"""Convert specified tree to graphviz instance. """Convert specified tree to graphviz instance.
...@@ -466,120 +464,124 @@ def _to_graphviz( ...@@ -466,120 +464,124 @@ def _to_graphviz(
if GRAPHVIZ_INSTALLED: if GRAPHVIZ_INSTALLED:
from graphviz import Digraph from graphviz import Digraph
else: else:
raise ImportError('You must install graphviz and restart your session to plot tree.') raise ImportError("You must install graphviz and restart your session to plot tree.")
def add( def add(
root: Dict[str, Any], root: Dict[str, Any], total_count: int, parent: Optional[str], decision: Optional[str], highlight: bool
total_count: int,
parent: Optional[str],
decision: Optional[str],
highlight: bool
) -> None: ) -> None:
"""Recursively add node or edge.""" """Recursively add node or edge."""
fillcolor = 'white' fillcolor = "white"
style = '' style = ""
tooltip = None tooltip = None
if highlight: if highlight:
color = 'blue' color = "blue"
penwidth = '3' penwidth = "3"
else: else:
color = 'black' color = "black"
penwidth = '1' penwidth = "1"
if 'split_index' in root: # non-leaf if "split_index" in root: # non-leaf
shape = "rectangle" shape = "rectangle"
l_dec = 'yes' l_dec = "yes"
r_dec = 'no' r_dec = "no"
threshold = root['threshold'] threshold = root["threshold"]
if root['decision_type'] == '<=': if root["decision_type"] == "<=":
operator = "&#8804;" operator = "&#8804;"
elif root['decision_type'] == '==': elif root["decision_type"] == "==":
operator = "=" operator = "="
else: else:
raise ValueError('Invalid decision type in tree model.') raise ValueError("Invalid decision type in tree model.")
name = f"split{root['split_index']}" name = f"split{root['split_index']}"
split_feature = root['split_feature'] split_feature = root["split_feature"]
if feature_names is not None: if feature_names is not None:
label = f"<B>{feature_names[split_feature]}</B> {operator}" label = f"<B>{feature_names[split_feature]}</B> {operator}"
else: else:
label = f"feature <B>{split_feature}</B> {operator} " label = f"feature <B>{split_feature}</B> {operator} "
direction = None direction = None
if example_case is not None: if example_case is not None:
if root['decision_type'] == '==': if root["decision_type"] == "==":
direction = _determine_direction_for_categorical_split( direction = _determine_direction_for_categorical_split(
fval=example_case[split_feature], fval=example_case[split_feature], thresholds=root["threshold"]
thresholds=root['threshold']
) )
else: else:
direction = _determine_direction_for_numeric_split( direction = _determine_direction_for_numeric_split(
fval=example_case[split_feature], fval=example_case[split_feature],
threshold=root['threshold'], threshold=root["threshold"],
missing_type_str=root['missing_type'], missing_type_str=root["missing_type"],
default_left=root['default_left'] default_left=root["default_left"],
) )
if root['decision_type'] == '==': if root["decision_type"] == "==":
category_values = root['threshold'].split('||') category_values = root["threshold"].split("||")
if len(category_values) > max_category_values: if len(category_values) > max_category_values:
tooltip = root['threshold'] tooltip = root["threshold"]
threshold = '||'.join(category_values[:2]) + '||...||' + category_values[-1] threshold = "||".join(category_values[:2]) + "||...||" + category_values[-1]
label += f"<B>{_float2str(threshold, precision)}</B>" label += f"<B>{_float2str(threshold, precision)}</B>"
for info in ['split_gain', 'internal_value', 'internal_weight', "internal_count", "data_percentage"]: for info in ["split_gain", "internal_value", "internal_weight", "internal_count", "data_percentage"]:
if info in show_info: if info in show_info:
output = info.split('_')[-1] output = info.split("_")[-1]
if info in {'split_gain', 'internal_value', 'internal_weight'}: if info in {"split_gain", "internal_value", "internal_weight"}:
label += f"<br/>{_float2str(root[info], precision)} {output}" label += f"<br/>{_float2str(root[info], precision)} {output}"
elif info == 'internal_count': elif info == "internal_count":
label += f"<br/>{output}: {root[info]}" label += f"<br/>{output}: {root[info]}"
elif info == "data_percentage": elif info == "data_percentage":
label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data" label += f"<br/>{_float2str(root['internal_count'] / total_count * 100, 2)}% of data"
if constraints: if constraints:
if constraints[root['split_feature']] == 1: if constraints[root["split_feature"]] == 1:
fillcolor = "#ddffdd" # light green fillcolor = "#ddffdd" # light green
if constraints[root['split_feature']] == -1: if constraints[root["split_feature"]] == -1:
fillcolor = "#ffdddd" # light red fillcolor = "#ffdddd" # light red
style = "filled" style = "filled"
label = f"<{label}>" label = f"<{label}>"
add( add(
root=root['left_child'], root=root["left_child"],
total_count=total_count, total_count=total_count,
parent=name, parent=name,
decision=l_dec, decision=l_dec,
highlight=highlight and direction == "left" highlight=highlight and direction == "left",
) )
add( add(
root=root['right_child'], root=root["right_child"],
total_count=total_count, total_count=total_count,
parent=name, parent=name,
decision=r_dec, decision=r_dec,
highlight=highlight and direction == "right" highlight=highlight and direction == "right",
) )
else: # leaf else: # leaf
shape = "ellipse" shape = "ellipse"
name = f"leaf{root['leaf_index']}" name = f"leaf{root['leaf_index']}"
label = f"leaf {root['leaf_index']}: " label = f"leaf {root['leaf_index']}: "
label += f"<B>{_float2str(root['leaf_value'], precision)}</B>" label += f"<B>{_float2str(root['leaf_value'], precision)}</B>"
if 'leaf_weight' in show_info: if "leaf_weight" in show_info:
label += f"<br/>{_float2str(root['leaf_weight'], precision)} weight" label += f"<br/>{_float2str(root['leaf_weight'], precision)} weight"
if 'leaf_count' in show_info: if "leaf_count" in show_info:
label += f"<br/>count: {root['leaf_count']}" label += f"<br/>count: {root['leaf_count']}"
if "data_percentage" in show_info: if "data_percentage" in show_info:
label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data" label += f"<br/>{_float2str(root['leaf_count'] / total_count * 100, 2)}% of data"
label = f"<{label}>" label = f"<{label}>"
graph.node(name, label=label, shape=shape, style=style, fillcolor=fillcolor, color=color, penwidth=penwidth, tooltip=tooltip) graph.node(
name,
label=label,
shape=shape,
style=style,
fillcolor=fillcolor,
color=color,
penwidth=penwidth,
tooltip=tooltip,
)
if parent is not None: if parent is not None:
graph.edge(parent, name, decision, color=color, penwidth=penwidth) graph.edge(parent, name, decision, color=color, penwidth=penwidth)
graph = Digraph(**kwargs) graph = Digraph(**kwargs)
rankdir = "LR" if orientation == "horizontal" else "TB" rankdir = "LR" if orientation == "horizontal" else "TB"
graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir) graph.attr("graph", nodesep="0.05", ranksep="0.3", rankdir=rankdir)
if "internal_count" in tree_info['tree_structure']: if "internal_count" in tree_info["tree_structure"]:
add( add(
root=tree_info['tree_structure'], root=tree_info["tree_structure"],
total_count=tree_info['tree_structure']["internal_count"], total_count=tree_info["tree_structure"]["internal_count"],
parent=None, parent=None,
decision=None, decision=None,
highlight=example_case is not None highlight=example_case is not None,
) )
else: else:
raise Exception("Cannot plot trees with no split") raise Exception("Cannot plot trees with no split")
...@@ -610,10 +612,10 @@ def create_tree_digraph( ...@@ -610,10 +612,10 @@ def create_tree_digraph(
tree_index: int = 0, tree_index: int = 0,
show_info: Optional[List[str]] = None, show_info: Optional[List[str]] = None,
precision: Optional[int] = 3, precision: Optional[int] = 3,
orientation: str = 'horizontal', orientation: str = "horizontal",
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None, example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
max_category_values: int = 10, max_category_values: int = 10,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
"""Create a digraph representation of specified tree. """Create a digraph representation of specified tree.
...@@ -689,32 +691,32 @@ def create_tree_digraph( ...@@ -689,32 +691,32 @@ def create_tree_digraph(
if isinstance(booster, LGBMModel): if isinstance(booster, LGBMModel):
booster = booster.booster_ booster = booster.booster_
elif not isinstance(booster, Booster): elif not isinstance(booster, Booster):
raise TypeError('booster must be Booster or LGBMModel.') raise TypeError("booster must be Booster or LGBMModel.")
model = booster.dump_model() model = booster.dump_model()
tree_infos = model['tree_info'] tree_infos = model["tree_info"]
feature_names = model.get('feature_names', None) feature_names = model.get("feature_names", None)
monotone_constraints = model.get('monotone_constraints', None) monotone_constraints = model.get("monotone_constraints", None)
if tree_index < len(tree_infos): if tree_index < len(tree_infos):
tree_info = tree_infos[tree_index] tree_info = tree_infos[tree_index]
else: else:
raise IndexError('tree_index is out of range.') raise IndexError("tree_index is out of range.")
if show_info is None: if show_info is None:
show_info = [] show_info = []
if example_case is not None: if example_case is not None:
if not isinstance(example_case, (np.ndarray, pd_DataFrame)) or example_case.ndim != 2: if not isinstance(example_case, (np.ndarray, pd_DataFrame)) or example_case.ndim != 2:
raise ValueError('example_case must be a numpy 2-D array or a pandas DataFrame') raise ValueError("example_case must be a numpy 2-D array or a pandas DataFrame")
if example_case.shape[0] != 1: if example_case.shape[0] != 1:
raise ValueError('example_case must have a single row.') raise ValueError("example_case must have a single row.")
if isinstance(example_case, pd_DataFrame): if isinstance(example_case, pd_DataFrame):
example_case = _data_from_pandas( example_case = _data_from_pandas(
data=example_case, data=example_case,
feature_name="auto", feature_name="auto",
categorical_feature="auto", categorical_feature="auto",
pandas_categorical=booster.pandas_categorical pandas_categorical=booster.pandas_categorical,
)[0] )[0]
example_case = example_case[0] example_case = example_case[0]
...@@ -727,7 +729,7 @@ def create_tree_digraph( ...@@ -727,7 +729,7 @@ def create_tree_digraph(
constraints=monotone_constraints, constraints=monotone_constraints,
example_case=example_case, example_case=example_case,
max_category_values=max_category_values, max_category_values=max_category_values,
**kwargs **kwargs,
) )
...@@ -739,9 +741,9 @@ def plot_tree( ...@@ -739,9 +741,9 @@ def plot_tree(
dpi: Optional[int] = None, dpi: Optional[int] = None,
show_info: Optional[List[str]] = None, show_info: Optional[List[str]] = None,
precision: Optional[int] = 3, precision: Optional[int] = 3,
orientation: str = 'horizontal', orientation: str = "horizontal",
example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None, example_case: Optional[Union[np.ndarray, pd_DataFrame]] = None,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
"""Plot specified tree. """Plot specified tree.
...@@ -807,22 +809,28 @@ def plot_tree( ...@@ -807,22 +809,28 @@ def plot_tree(
import matplotlib.image import matplotlib.image
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
else: else:
raise ImportError('You must install matplotlib and restart your session to plot tree.') raise ImportError("You must install matplotlib and restart your session to plot tree.")
if ax is None: if ax is None:
if figsize is not None: if figsize is not None:
_check_not_tuple_of_2_elements(figsize, 'figsize') _check_not_tuple_of_2_elements(figsize, "figsize")
_, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) _, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
graph = create_tree_digraph(booster=booster, tree_index=tree_index, graph = create_tree_digraph(
show_info=show_info, precision=precision, booster=booster,
orientation=orientation, example_case=example_case, **kwargs) tree_index=tree_index,
show_info=show_info,
precision=precision,
orientation=orientation,
example_case=example_case,
**kwargs,
)
s = BytesIO() s = BytesIO()
s.write(graph.pipe(format='png')) s.write(graph.pipe(format="png"))
s.seek(0) s.seek(0)
img = matplotlib.image.imread(s) img = matplotlib.image.imread(s)
ax.imshow(img) ax.imshow(img)
ax.axis('off') ax.axis("off")
return ax return ax
...@@ -46,10 +46,10 @@ from .compat import ( ...@@ -46,10 +46,10 @@ from .compat import (
from .engine import train from .engine import train
__all__ = [ __all__ = [
'LGBMClassifier', "LGBMClassifier",
'LGBMModel', "LGBMModel",
'LGBMRanker', "LGBMRanker",
'LGBMRegressor', "LGBMRegressor",
] ]
_LGBM_ScikitMatrixLike = Union[ _LGBM_ScikitMatrixLike = Union[
...@@ -57,58 +57,58 @@ _LGBM_ScikitMatrixLike = Union[ ...@@ -57,58 +57,58 @@ _LGBM_ScikitMatrixLike = Union[
List[Union[List[float], List[int]]], List[Union[List[float], List[int]]],
np.ndarray, np.ndarray,
pd_DataFrame, pd_DataFrame,
scipy.sparse.spmatrix scipy.sparse.spmatrix,
] ]
_LGBM_ScikitCustomObjectiveFunction = Union[ _LGBM_ScikitCustomObjectiveFunction = Union[
# f(labels, preds) # f(labels, preds)
Callable[ Callable[
[Optional[np.ndarray], np.ndarray], [Optional[np.ndarray], np.ndarray],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray],
], ],
# f(labels, preds, weights) # f(labels, preds, weights)
Callable[ Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray],
], ],
# f(labels, preds, weights, group) # f(labels, preds, weights, group)
Callable[ Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
Tuple[np.ndarray, np.ndarray] Tuple[np.ndarray, np.ndarray],
], ],
] ]
_LGBM_ScikitCustomEvalFunction = Union[ _LGBM_ScikitCustomEvalFunction = Union[
# f(labels, preds) # f(labels, preds)
Callable[ Callable[
[Optional[np.ndarray], np.ndarray], [Optional[np.ndarray], np.ndarray],
_LGBM_EvalFunctionResultType _LGBM_EvalFunctionResultType,
], ],
Callable[ Callable[
[Optional[np.ndarray], np.ndarray], [Optional[np.ndarray], np.ndarray],
List[_LGBM_EvalFunctionResultType] List[_LGBM_EvalFunctionResultType],
], ],
# f(labels, preds, weights) # f(labels, preds, weights)
Callable[ Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
_LGBM_EvalFunctionResultType _LGBM_EvalFunctionResultType,
], ],
Callable[ Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray]], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]],
List[_LGBM_EvalFunctionResultType] List[_LGBM_EvalFunctionResultType],
], ],
# f(labels, preds, weights, group) # f(labels, preds, weights, group)
Callable[ Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
_LGBM_EvalFunctionResultType _LGBM_EvalFunctionResultType,
], ],
Callable[ Callable[
[Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]], [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]],
List[_LGBM_EvalFunctionResultType] List[_LGBM_EvalFunctionResultType],
] ],
] ]
_LGBM_ScikitEvalMetricType = Union[ _LGBM_ScikitEvalMetricType = Union[
str, str,
_LGBM_ScikitCustomEvalFunction, _LGBM_ScikitCustomEvalFunction,
List[Union[str, _LGBM_ScikitCustomEvalFunction]] List[Union[str, _LGBM_ScikitCustomEvalFunction]],
] ]
_LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType] _LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]
...@@ -119,7 +119,7 @@ def _get_group_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray ...@@ -119,7 +119,7 @@ def _get_group_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarray
"Estimators in lightgbm.sklearn should only retrieve query groups from a constructed Dataset. " "Estimators in lightgbm.sklearn should only retrieve query groups from a constructed Dataset. "
"If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
) )
assert (group is None or isinstance(group, np.ndarray)), error_msg assert group is None or isinstance(group, np.ndarray), error_msg
return group return group
...@@ -139,7 +139,7 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarra ...@@ -139,7 +139,7 @@ def _get_weight_from_constructed_dataset(dataset: Dataset) -> Optional[np.ndarra
"Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. " "Estimators in lightgbm.sklearn should only retrieve weights from a constructed Dataset. "
"If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues." "If you're seeing this message, it's a bug in lightgbm. Please report it at https://github.com/microsoft/LightGBM/issues."
) )
assert (weight is None or isinstance(weight, np.ndarray)), error_msg assert weight is None or isinstance(weight, np.ndarray), error_msg
return weight return weight
...@@ -189,7 +189,11 @@ class _ObjectiveFunctionWrapper: ...@@ -189,7 +189,11 @@ class _ObjectiveFunctionWrapper:
""" """
self.func = func self.func = func
def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np.ndarray]: def __call__(
self,
preds: np.ndarray,
dataset: Dataset,
) -> Tuple[np.ndarray, np.ndarray]:
"""Call passed function with appropriate arguments. """Call passed function with appropriate arguments.
Parameters Parameters
...@@ -271,7 +275,7 @@ class _EvalFunctionWrapper: ...@@ -271,7 +275,7 @@ class _EvalFunctionWrapper:
def __call__( def __call__(
self, self,
preds: np.ndarray, preds: np.ndarray,
dataset: Dataset dataset: Dataset,
) -> Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]: ) -> Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]]:
"""Call passed function with appropriate arguments. """Call passed function with appropriate arguments.
...@@ -310,8 +314,7 @@ class _EvalFunctionWrapper: ...@@ -310,8 +314,7 @@ class _EvalFunctionWrapper:
# documentation templates for LGBMModel methods are shared between the classes in # documentation templates for LGBMModel methods are shared between the classes in
# this module and those in the ``dask`` module # this module and those in the ``dask`` module
_lgbmmodel_doc_fit = ( _lgbmmodel_doc_fit = """
"""
Build a gradient boosting model from the training set (X, y). Build a gradient boosting model from the training set (X, y).
Parameters Parameters
...@@ -372,7 +375,6 @@ _lgbmmodel_doc_fit = ( ...@@ -372,7 +375,6 @@ _lgbmmodel_doc_fit = (
self : LGBMModel self : LGBMModel
Returns self. Returns self.
""" """
)
_lgbmmodel_doc_custom_eval_note = """ _lgbmmodel_doc_custom_eval_note = """
Note Note
...@@ -405,8 +407,7 @@ _lgbmmodel_doc_custom_eval_note = """ ...@@ -405,8 +407,7 @@ _lgbmmodel_doc_custom_eval_note = """
Is eval result higher better, e.g. AUC is ``is_higher_better``. Is eval result higher better, e.g. AUC is ``is_higher_better``.
""" """
_lgbmmodel_doc_predict = ( _lgbmmodel_doc_predict = """
"""
{description} {description}
Parameters Parameters
...@@ -451,7 +452,6 @@ _lgbmmodel_doc_predict = ( ...@@ -451,7 +452,6 @@ _lgbmmodel_doc_predict = (
X_SHAP_values : {X_SHAP_values_shape} X_SHAP_values : {X_SHAP_values_shape}
If ``pred_contrib=True``, the feature contributions for each sample. If ``pred_contrib=True``, the feature contributions for each sample.
""" """
)
class LGBMModel(_LGBMModelBase): class LGBMModel(_LGBMModelBase):
...@@ -459,7 +459,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -459,7 +459,7 @@ class LGBMModel(_LGBMModelBase):
def __init__( def __init__(
self, self,
boosting_type: str = 'gbdt', boosting_type: str = "gbdt",
num_leaves: int = 31, num_leaves: int = 31,
max_depth: int = -1, max_depth: int = -1,
learning_rate: float = 0.1, learning_rate: float = 0.1,
...@@ -467,18 +467,18 @@ class LGBMModel(_LGBMModelBase): ...@@ -467,18 +467,18 @@ class LGBMModel(_LGBMModelBase):
subsample_for_bin: int = 200000, subsample_for_bin: int = 200000,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None, class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0., min_split_gain: float = 0.0,
min_child_weight: float = 1e-3, min_child_weight: float = 1e-3,
min_child_samples: int = 20, min_child_samples: int = 20,
subsample: float = 1., subsample: float = 1.0,
subsample_freq: int = 0, subsample_freq: int = 0,
colsample_bytree: float = 1., colsample_bytree: float = 1.0,
reg_alpha: float = 0., reg_alpha: float = 0.0,
reg_lambda: float = 0., reg_lambda: float = 0.0,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None, random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None, n_jobs: Optional[int] = None,
importance_type: str = 'split', importance_type: str = "split",
**kwargs **kwargs,
): ):
r"""Construct a gradient boosting model. r"""Construct a gradient boosting model.
...@@ -598,8 +598,10 @@ class LGBMModel(_LGBMModelBase): ...@@ -598,8 +598,10 @@ class LGBMModel(_LGBMModelBase):
and grad and hess should be returned in the same format. and grad and hess should be returned in the same format.
""" """
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise LightGBMError('scikit-learn is required for lightgbm.sklearn. ' raise LightGBMError(
'You must install scikit-learn and restart your session to use this module.') "scikit-learn is required for lightgbm.sklearn. "
"You must install scikit-learn and restart your session to use this module."
)
self.boosting_type = boosting_type self.boosting_type = boosting_type
self.objective = objective self.objective = objective
...@@ -636,14 +638,13 @@ class LGBMModel(_LGBMModelBase): ...@@ -636,14 +638,13 @@ class LGBMModel(_LGBMModelBase):
def _more_tags(self) -> Dict[str, Any]: def _more_tags(self) -> Dict[str, Any]:
return { return {
'allow_nan': True, "allow_nan": True,
'X_types': ['2darray', 'sparse', '1dlabels'], "X_types": ["2darray", "sparse", "1dlabels"],
'_xfail_checks': { "_xfail_checks": {
'check_no_attributes_set_in_init': "check_no_attributes_set_in_init": "scikit-learn incorrectly asserts that private attributes "
'scikit-learn incorrectly asserts that private attributes ' "cannot be set in __init__: "
'cannot be set in __init__: ' "(see https://github.com/microsoft/LightGBM/issues/2628)"
'(see https://github.com/microsoft/LightGBM/issues/2628)' },
}
} }
def __sklearn_is_fitted__(self) -> bool: def __sklearn_is_fitted__(self) -> bool:
...@@ -703,8 +704,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -703,8 +704,8 @@ class LGBMModel(_LGBMModelBase):
assert stage in {"fit", "predict"} assert stage in {"fit", "predict"}
params = self.get_params() params = self.get_params()
params.pop('objective', None) params.pop("objective", None)
for alias in _ConfigAliases.get('objective'): for alias in _ConfigAliases.get("objective"):
if alias in params: if alias in params:
obj = params.pop(alias) obj = params.pop(alias)
_log_warning(f"Found '{alias}' in params. Will use it instead of 'objective' argument") _log_warning(f"Found '{alias}' in params. Will use it instead of 'objective' argument")
...@@ -725,33 +726,31 @@ class LGBMModel(_LGBMModelBase): ...@@ -725,33 +726,31 @@ class LGBMModel(_LGBMModelBase):
raise ValueError("Unknown LGBMModel type.") raise ValueError("Unknown LGBMModel type.")
if callable(self._objective): if callable(self._objective):
if stage == "fit": if stage == "fit":
params['objective'] = _ObjectiveFunctionWrapper(self._objective) params["objective"] = _ObjectiveFunctionWrapper(self._objective)
else: else:
params['objective'] = 'None' params["objective"] = "None"
else: else:
params['objective'] = self._objective params["objective"] = self._objective
params.pop('importance_type', None) params.pop("importance_type", None)
params.pop('n_estimators', None) params.pop("n_estimators", None)
params.pop('class_weight', None) params.pop("class_weight", None)
if isinstance(params['random_state'], np.random.RandomState): if isinstance(params["random_state"], np.random.RandomState):
params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max) params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max)
elif isinstance(params['random_state'], np_random_Generator): elif isinstance(params["random_state"], np_random_Generator):
params['random_state'] = int( params["random_state"] = int(params["random_state"].integers(np.iinfo(np.int32).max))
params['random_state'].integers(np.iinfo(np.int32).max)
)
if self._n_classes > 2: if self._n_classes > 2:
for alias in _ConfigAliases.get('num_class'): for alias in _ConfigAliases.get("num_class"):
params.pop(alias, None) params.pop(alias, None)
params['num_class'] = self._n_classes params["num_class"] = self._n_classes
if hasattr(self, '_eval_at'): if hasattr(self, "_eval_at"):
eval_at = self._eval_at eval_at = self._eval_at
for alias in _ConfigAliases.get('eval_at'): for alias in _ConfigAliases.get("eval_at"):
if alias in params: if alias in params:
_log_warning(f"Found '{alias}' in params. Will use it instead of 'eval_at' argument") _log_warning(f"Found '{alias}' in params. Will use it instead of 'eval_at' argument")
eval_at = params.pop(alias) eval_at = params.pop(alias)
params['eval_at'] = eval_at params["eval_at"] = eval_at
# register default metric for consistency with callable eval_metric case # register default metric for consistency with callable eval_metric case
original_metric = self._objective if isinstance(self._objective, str) else None original_metric = self._objective if isinstance(self._objective, str) else None
...@@ -809,10 +808,10 @@ class LGBMModel(_LGBMModelBase): ...@@ -809,10 +808,10 @@ class LGBMModel(_LGBMModelBase):
eval_init_score: Optional[List[_LGBM_InitScoreType]] = None, eval_init_score: Optional[List[_LGBM_InitScoreType]] = None,
eval_group: Optional[List[_LGBM_GroupType]] = None, eval_group: Optional[List[_LGBM_GroupType]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
callbacks: Optional[List[Callable]] = None, callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None,
) -> "LGBMModel": ) -> "LGBMModel":
"""Docstring is set after definition, using a template.""" """Docstring is set after definition, using a template."""
params = self._process_params(stage="fit") params = self._process_params(stage="fit")
...@@ -832,9 +831,9 @@ class LGBMModel(_LGBMModelBase): ...@@ -832,9 +831,9 @@ class LGBMModel(_LGBMModelBase):
eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)] eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]
# concatenate metric from params (or default if not provided in params) and eval_metric # concatenate metric from params (or default if not provided in params) and eval_metric
params['metric'] = [params['metric']] if isinstance(params['metric'], (str, type(None))) else params['metric'] params["metric"] = [params["metric"]] if isinstance(params["metric"], (str, type(None))) else params["metric"]
params['metric'] = [e for e in eval_metrics_builtin if e not in params['metric']] + params['metric'] params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
params['metric'] = [metric for metric in params['metric'] if metric is not None] params["metric"] = [metric for metric in params["metric"] if metric is not None]
if not isinstance(X, (pd_DataFrame, dt_DataTable)): if not isinstance(X, (pd_DataFrame, dt_DataTable)):
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2) _X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
...@@ -856,9 +855,15 @@ class LGBMModel(_LGBMModelBase): ...@@ -856,9 +855,15 @@ class LGBMModel(_LGBMModelBase):
# copy for consistency # copy for consistency
self._n_features_in = self._n_features self._n_features_in = self._n_features
train_set = Dataset(data=_X, label=_y, weight=sample_weight, group=group, train_set = Dataset(
init_score=init_score, categorical_feature=categorical_feature, data=_X,
params=params) label=_y,
weight=sample_weight,
group=group,
init_score=init_score,
categorical_feature=categorical_feature,
params=params,
)
valid_sets: List[Dataset] = [] valid_sets: List[Dataset] = []
if eval_set is not None: if eval_set is not None:
...@@ -880,8 +885,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -880,8 +885,8 @@ class LGBMModel(_LGBMModelBase):
if valid_data[0] is X and valid_data[1] is y: if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set valid_set = train_set
else: else:
valid_weight = _get_meta_data(eval_sample_weight, 'eval_sample_weight', i) valid_weight = _get_meta_data(eval_sample_weight, "eval_sample_weight", i)
valid_class_weight = _get_meta_data(eval_class_weight, 'eval_class_weight', i) valid_class_weight = _get_meta_data(eval_class_weight, "eval_class_weight", i)
if valid_class_weight is not None: if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None: if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()} valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
...@@ -890,11 +895,17 @@ class LGBMModel(_LGBMModelBase): ...@@ -890,11 +895,17 @@ class LGBMModel(_LGBMModelBase):
valid_weight = valid_class_sample_weight valid_weight = valid_class_sample_weight
else: else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight) valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, 'eval_init_score', i) valid_init_score = _get_meta_data(eval_init_score, "eval_init_score", i)
valid_group = _get_meta_data(eval_group, 'eval_group', i) valid_group = _get_meta_data(eval_group, "eval_group", i)
valid_set = Dataset(data=valid_data[0], label=valid_data[1], weight=valid_weight, valid_set = Dataset(
group=valid_group, init_score=valid_init_score, data=valid_data[0],
categorical_feature='auto', params=params) label=valid_data[1],
weight=valid_weight,
group=valid_group,
init_score=valid_init_score,
categorical_feature="auto",
params=params,
)
valid_sets.append(valid_set) valid_sets.append(valid_set)
...@@ -918,7 +929,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -918,7 +929,7 @@ class LGBMModel(_LGBMModelBase):
feval=eval_metrics_callable, # type: ignore[arg-type] feval=eval_metrics_callable, # type: ignore[arg-type]
init_model=init_model, init_model=init_model,
feature_name=feature_name, feature_name=feature_name,
callbacks=callbacks callbacks=callbacks,
) )
self._evals_result = evals_result self._evals_result = evals_result
...@@ -932,16 +943,20 @@ class LGBMModel(_LGBMModelBase): ...@@ -932,16 +943,20 @@ class LGBMModel(_LGBMModelBase):
del train_set, valid_sets del train_set, valid_sets
return self return self
fit.__doc__ = _lgbmmodel_doc_fit.format( fit.__doc__ = (
X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame , scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]", _lgbmmodel_doc_fit.format(
y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]", X_shape="numpy array, pandas DataFrame, H2O DataTable's Frame , scipy.sparse, list of lists of int or float of shape = [n_samples, n_features]",
sample_weight_shape="numpy array, pandas Series, list of int or float of shape = [n_samples] or None, optional (default=None)", y_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples]",
init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)", sample_weight_shape="numpy array, pandas Series, list of int or float of shape = [n_samples] or None, optional (default=None)",
group_shape="numpy array, pandas Series, list of int or float, or None, optional (default=None)", init_score_shape="numpy array, pandas DataFrame, pandas Series, list of int or float of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task) or shape = [n_samples, n_classes] (for multi-class task) or None, optional (default=None)",
eval_sample_weight_shape="list of array (same types as ``sample_weight`` supports), or None, optional (default=None)", group_shape="numpy array, pandas Series, list of int or float, or None, optional (default=None)",
eval_init_score_shape="list of array (same types as ``init_score`` supports), or None, optional (default=None)", eval_sample_weight_shape="list of array (same types as ``sample_weight`` supports), or None, optional (default=None)",
eval_group_shape="list of array (same types as ``group`` supports), or None, optional (default=None)" eval_init_score_shape="list of array (same types as ``init_score`` supports), or None, optional (default=None)",
) + "\n\n" + _lgbmmodel_doc_custom_eval_note eval_group_shape="list of array (same types as ``group`` supports), or None, optional (default=None)",
)
+ "\n\n"
+ _lgbmmodel_doc_custom_eval_note
)
def predict( def predict(
self, self,
...@@ -952,7 +967,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -952,7 +967,7 @@ class LGBMModel(_LGBMModelBase):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
): ):
"""Docstring is set after definition, using a template.""" """Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
...@@ -961,9 +976,11 @@ class LGBMModel(_LGBMModelBase): ...@@ -961,9 +976,11 @@ class LGBMModel(_LGBMModelBase):
X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False) X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)
n_features = X.shape[1] n_features = X.shape[1]
if self._n_features != n_features: if self._n_features != n_features:
raise ValueError("Number of features of the model must " raise ValueError(
f"match the input. Model n_features_ is {self._n_features} and " "Number of features of the model must "
f"input n_features is {n_features}") f"match the input. Model n_features_ is {self._n_features} and "
f"input n_features is {n_features}"
)
# retrive original params that possibly can be used in both training and prediction # retrive original params that possibly can be used in both training and prediction
# and then overwrite them (considering aliases) with params that were passed directly in prediction # and then overwrite them (considering aliases) with params that were passed directly in prediction
predict_params = self._process_params(stage="predict") predict_params = self._process_params(stage="predict")
...@@ -975,7 +992,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -975,7 +992,7 @@ class LGBMModel(_LGBMModelBase):
"num_iteration", "num_iteration",
"pred_leaf", "pred_leaf",
"pred_contrib", "pred_contrib",
*kwargs.keys() *kwargs.keys(),
): ):
predict_params.pop(alias, None) predict_params.pop(alias, None)
predict_params.update(kwargs) predict_params.update(kwargs)
...@@ -986,9 +1003,14 @@ class LGBMModel(_LGBMModelBase): ...@@ -986,9 +1003,14 @@ class LGBMModel(_LGBMModelBase):
predict_params["num_threads"] = self._process_n_jobs(predict_params["num_threads"]) predict_params["num_threads"] = self._process_n_jobs(predict_params["num_threads"])
return self._Booster.predict( # type: ignore[union-attr] return self._Booster.predict( # type: ignore[union-attr]
X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration, X,
pred_leaf=pred_leaf, pred_contrib=pred_contrib, validate_features=validate_features, raw_score=raw_score,
**predict_params start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**predict_params,
) )
predict.__doc__ = _lgbmmodel_doc_predict.format( predict.__doc__ = _lgbmmodel_doc_predict.format(
...@@ -997,42 +1019,44 @@ class LGBMModel(_LGBMModelBase): ...@@ -997,42 +1019,44 @@ class LGBMModel(_LGBMModelBase):
output_name="predicted_result", output_name="predicted_result",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]", predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects" X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects",
) )
@property @property
def n_features_(self) -> int: def n_features_(self) -> int:
""":obj:`int`: The number of features of fitted model.""" """:obj:`int`: The number of features of fitted model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_features found. Need to call fit beforehand.') raise LGBMNotFittedError("No n_features found. Need to call fit beforehand.")
return self._n_features return self._n_features
@property @property
def n_features_in_(self) -> int: def n_features_in_(self) -> int:
""":obj:`int`: The number of features of fitted model.""" """:obj:`int`: The number of features of fitted model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_features_in found. Need to call fit beforehand.') raise LGBMNotFittedError("No n_features_in found. Need to call fit beforehand.")
return self._n_features_in return self._n_features_in
@property @property
def best_score_(self) -> _LGBM_BoosterBestScoreType: def best_score_(self) -> _LGBM_BoosterBestScoreType:
""":obj:`dict`: The best score of fitted model.""" """:obj:`dict`: The best score of fitted model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.') raise LGBMNotFittedError("No best_score found. Need to call fit beforehand.")
return self._best_score return self._best_score
@property @property
def best_iteration_(self) -> int: def best_iteration_(self) -> int:
""":obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified.""" """:obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.') raise LGBMNotFittedError(
"No best_iteration found. Need to call fit with early_stopping callback beforehand."
)
return self._best_iteration return self._best_iteration
@property @property
def objective_(self) -> Union[str, _LGBM_ScikitCustomObjectiveFunction]: def objective_(self) -> Union[str, _LGBM_ScikitCustomObjectiveFunction]:
""":obj:`str` or :obj:`callable`: The concrete objective used while fitting this model.""" """:obj:`str` or :obj:`callable`: The concrete objective used while fitting this model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No objective found. Need to call fit beforehand.') raise LGBMNotFittedError("No objective found. Need to call fit beforehand.")
return self._objective # type: ignore[return-value] return self._objective # type: ignore[return-value]
@property @property
...@@ -1041,11 +1065,11 @@ class LGBMModel(_LGBMModelBase): ...@@ -1041,11 +1065,11 @@ class LGBMModel(_LGBMModelBase):
This might be less than parameter ``n_estimators`` if early stopping was enabled or This might be less than parameter ``n_estimators`` if early stopping was enabled or
if boosting stopped early due to limits on complexity like ``min_gain_to_split``. if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
.. versionadded:: 4.0.0 .. versionadded:: 4.0.0
""" """
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_estimators found. Need to call fit beforehand.') raise LGBMNotFittedError("No n_estimators found. Need to call fit beforehand.")
return self._Booster.current_iteration() # type: ignore return self._Booster.current_iteration() # type: ignore
@property @property
...@@ -1054,25 +1078,25 @@ class LGBMModel(_LGBMModelBase): ...@@ -1054,25 +1078,25 @@ class LGBMModel(_LGBMModelBase):
This might be less than parameter ``n_estimators`` if early stopping was enabled or This might be less than parameter ``n_estimators`` if early stopping was enabled or
if boosting stopped early due to limits on complexity like ``min_gain_to_split``. if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
.. versionadded:: 4.0.0 .. versionadded:: 4.0.0
""" """
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No n_iter found. Need to call fit beforehand.') raise LGBMNotFittedError("No n_iter found. Need to call fit beforehand.")
return self._Booster.current_iteration() # type: ignore return self._Booster.current_iteration() # type: ignore
@property @property
def booster_(self) -> Booster: def booster_(self) -> Booster:
"""Booster: The underlying Booster of this model.""" """Booster: The underlying Booster of this model."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No booster found. Need to call fit beforehand.') raise LGBMNotFittedError("No booster found. Need to call fit beforehand.")
return self._Booster # type: ignore[return-value] return self._Booster # type: ignore[return-value]
@property @property
def evals_result_(self) -> _EvalResultDict: def evals_result_(self) -> _EvalResultDict:
""":obj:`dict`: The evaluation results if validation sets have been specified.""" """:obj:`dict`: The evaluation results if validation sets have been specified."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.') raise LGBMNotFittedError("No results found. Need to call fit with eval_set beforehand.")
return self._evals_result return self._evals_result
@property @property
...@@ -1085,14 +1109,14 @@ class LGBMModel(_LGBMModelBase): ...@@ -1085,14 +1109,14 @@ class LGBMModel(_LGBMModelBase):
to configure the type of importance values to be extracted. to configure the type of importance values to be extracted.
""" """
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No feature_importances found. Need to call fit beforehand.') raise LGBMNotFittedError("No feature_importances found. Need to call fit beforehand.")
return self._Booster.feature_importance(importance_type=self.importance_type) # type: ignore[union-attr] return self._Booster.feature_importance(importance_type=self.importance_type) # type: ignore[union-attr]
@property @property
def feature_name_(self) -> List[str]: def feature_name_(self) -> List[str]:
""":obj:`list` of shape = [n_features]: The names of features.""" """:obj:`list` of shape = [n_features]: The names of features."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No feature_name found. Need to call fit beforehand.') raise LGBMNotFittedError("No feature_name found. Need to call fit beforehand.")
return self._Booster.feature_name() # type: ignore[union-attr] return self._Booster.feature_name() # type: ignore[union-attr]
...@@ -1110,10 +1134,10 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel): ...@@ -1110,10 +1134,10 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
eval_sample_weight: Optional[List[_LGBM_WeightType]] = None, eval_sample_weight: Optional[List[_LGBM_WeightType]] = None,
eval_init_score: Optional[List[_LGBM_InitScoreType]] = None, eval_init_score: Optional[List[_LGBM_InitScoreType]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
callbacks: Optional[List[Callable]] = None, callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None,
) -> "LGBMRegressor": ) -> "LGBMRegressor":
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
super().fit( super().fit(
...@@ -1129,17 +1153,17 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel): ...@@ -1129,17 +1153,17 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
feature_name=feature_name, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks, callbacks=callbacks,
init_model=init_model init_model=init_model,
) )
return self return self
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMRegressor") # type: ignore _base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMRegressor") # type: ignore
_base_doc = (_base_doc[:_base_doc.find('group :')] # type: ignore _base_doc = (
+ _base_doc[_base_doc.find('eval_set :'):]) # type: ignore _base_doc[: _base_doc.find("group :")] # type: ignore
_base_doc = (_base_doc[:_base_doc.find('eval_class_weight :')] + _base_doc[_base_doc.find("eval_set :") :]
+ _base_doc[_base_doc.find('eval_init_score :'):]) ) # type: ignore
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')] _base_doc = _base_doc[: _base_doc.find("eval_class_weight :")] + _base_doc[_base_doc.find("eval_init_score :") :]
+ _base_doc[_base_doc.find('eval_metric :'):]) fit.__doc__ = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]
class LGBMClassifier(_LGBMClassifierBase, LGBMModel): class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
...@@ -1157,10 +1181,10 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1157,10 +1181,10 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
eval_class_weight: Optional[List[float]] = None, eval_class_weight: Optional[List[float]] = None,
eval_init_score: Optional[List[_LGBM_InitScoreType]] = None, eval_init_score: Optional[List[_LGBM_InitScoreType]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
callbacks: Optional[List[Callable]] = None, callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None,
) -> "LGBMClassifier": ) -> "LGBMClassifier":
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y) _LGBMAssertAllFinite(y)
...@@ -1187,16 +1211,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1187,16 +1211,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
eval_metric_list = [] eval_metric_list = []
if self._n_classes > 2: if self._n_classes > 2:
for index, metric in enumerate(eval_metric_list): for index, metric in enumerate(eval_metric_list):
if metric in {'logloss', 'binary_logloss'}: if metric in {"logloss", "binary_logloss"}:
eval_metric_list[index] = "multi_logloss" eval_metric_list[index] = "multi_logloss"
elif metric in {'error', 'binary_error'}: elif metric in {"error", "binary_error"}:
eval_metric_list[index] = "multi_error" eval_metric_list[index] = "multi_error"
else: else:
for index, metric in enumerate(eval_metric_list): for index, metric in enumerate(eval_metric_list):
if metric in {'logloss', 'multi_logloss'}: if metric in {"logloss", "multi_logloss"}:
eval_metric_list[index] = 'binary_logloss' eval_metric_list[index] = "binary_logloss"
elif metric in {'error', 'multi_error'}: elif metric in {"error", "multi_error"}:
eval_metric_list[index] = 'binary_error' eval_metric_list[index] = "binary_error"
eval_metric = eval_metric_list eval_metric = eval_metric_list
# do not modify args, as it causes errors in model selection tools # do not modify args, as it causes errors in model selection tools
...@@ -1225,15 +1249,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1225,15 +1249,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
feature_name=feature_name, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks, callbacks=callbacks,
init_model=init_model init_model=init_model,
) )
return self return self
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMClassifier") # type: ignore _base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMClassifier") # type: ignore
_base_doc = (_base_doc[:_base_doc.find('group :')] # type: ignore _base_doc = (
+ _base_doc[_base_doc.find('eval_set :'):]) # type: ignore _base_doc[: _base_doc.find("group :")] # type: ignore
fit.__doc__ = (_base_doc[:_base_doc.find('eval_group :')] + _base_doc[_base_doc.find("eval_set :") :]
+ _base_doc[_base_doc.find('eval_metric :'):]) ) # type: ignore
fit.__doc__ = _base_doc[: _base_doc.find("eval_group :")] + _base_doc[_base_doc.find("eval_metric :") :]
def predict( def predict(
self, self,
...@@ -1244,7 +1269,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1244,7 +1269,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
): ):
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
result = self.predict_proba( result = self.predict_proba(
...@@ -1255,7 +1280,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1255,7 +1280,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
validate_features=validate_features, validate_features=validate_features,
**kwargs **kwargs,
) )
if callable(self._objective) or raw_score or pred_leaf or pred_contrib: if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
return result return result
...@@ -1274,7 +1299,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1274,7 +1299,7 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
pred_leaf: bool = False, pred_leaf: bool = False,
pred_contrib: bool = False, pred_contrib: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any,
): ):
"""Docstring is set after definition, using a template.""" """Docstring is set after definition, using a template."""
result = super().predict( result = super().predict(
...@@ -1285,17 +1310,19 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1285,17 +1310,19 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contrib=pred_contrib, pred_contrib=pred_contrib,
validate_features=validate_features, validate_features=validate_features,
**kwargs **kwargs,
) )
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib): if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
_log_warning("Cannot compute class probabilities or labels " _log_warning(
"due to the usage of customized objective function.\n" "Cannot compute class probabilities or labels "
"Returning raw scores instead.") "due to the usage of customized objective function.\n"
"Returning raw scores instead."
)
return result return result
elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator] elif self._n_classes > 2 or raw_score or pred_leaf or pred_contrib: # type: ignore [operator]
return result return result
else: else:
return np.vstack((1. - result, result)).transpose() return np.vstack((1.0 - result, result)).transpose()
predict_proba.__doc__ = _lgbmmodel_doc_predict.format( predict_proba.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted probability for each class for each sample.", description="Return the predicted probability for each class for each sample.",
...@@ -1303,21 +1330,21 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1303,21 +1330,21 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
output_name="predicted_probability", output_name="predicted_probability",
predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]", predicted_result_shape="array-like of shape = [n_samples] or shape = [n_samples, n_classes]",
X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]", X_leaves_shape="array-like of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]",
X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects" X_SHAP_values_shape="array-like of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects",
) )
@property @property
def classes_(self) -> np.ndarray: def classes_(self) -> np.ndarray:
""":obj:`array` of shape = [n_classes]: The class label array.""" """:obj:`array` of shape = [n_classes]: The class label array."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No classes found. Need to call fit beforehand.') raise LGBMNotFittedError("No classes found. Need to call fit beforehand.")
return self._classes # type: ignore[return-value] return self._classes # type: ignore[return-value]
@property @property
def n_classes_(self) -> int: def n_classes_(self) -> int:
""":obj:`int`: The number of classes.""" """:obj:`int`: The number of classes."""
if not self.__sklearn_is_fitted__(): if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No classes found. Need to call fit beforehand.') raise LGBMNotFittedError("No classes found. Need to call fit beforehand.")
return self._n_classes return self._n_classes
...@@ -1345,10 +1372,10 @@ class LGBMRanker(LGBMModel): ...@@ -1345,10 +1372,10 @@ class LGBMRanker(LGBMModel):
eval_group: Optional[List[_LGBM_GroupType]] = None, eval_group: Optional[List[_LGBM_GroupType]] = None,
eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None, eval_metric: Optional[_LGBM_ScikitEvalMetricType] = None,
eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5), eval_at: Union[List[int], Tuple[int, ...]] = (1, 2, 3, 4, 5),
feature_name: _LGBM_FeatureNameConfiguration = 'auto', feature_name: _LGBM_FeatureNameConfiguration = "auto",
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto', categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto",
callbacks: Optional[List[Callable]] = None, callbacks: Optional[List[Callable]] = None,
init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None,
) -> "LGBMRanker": ) -> "LGBMRanker":
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
# check group data # check group data
...@@ -1360,12 +1387,16 @@ class LGBMRanker(LGBMModel): ...@@ -1360,12 +1387,16 @@ class LGBMRanker(LGBMModel):
raise ValueError("Eval_group cannot be None when eval_set is not None") raise ValueError("Eval_group cannot be None when eval_set is not None")
elif len(eval_group) != len(eval_set): elif len(eval_group) != len(eval_set):
raise ValueError("Length of eval_group should be equal to eval_set") raise ValueError("Length of eval_group should be equal to eval_set")
elif (isinstance(eval_group, dict) elif (
and any(i not in eval_group or eval_group[i] is None for i in range(len(eval_group))) isinstance(eval_group, dict)
or isinstance(eval_group, list) and any(i not in eval_group or eval_group[i] is None for i in range(len(eval_group)))
and any(group is None for group in eval_group)): or isinstance(eval_group, list)
raise ValueError("Should set group for all eval datasets for ranking task; " and any(group is None for group in eval_group)
"if you use dict, the index should start from 0") ):
raise ValueError(
"Should set group for all eval datasets for ranking task; "
"if you use dict, the index should start from 0"
)
self._eval_at = eval_at self._eval_at = eval_at
super().fit( super().fit(
...@@ -1383,15 +1414,17 @@ class LGBMRanker(LGBMModel): ...@@ -1383,15 +1414,17 @@ class LGBMRanker(LGBMModel):
feature_name=feature_name, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks, callbacks=callbacks,
init_model=init_model init_model=init_model,
) )
return self return self
_base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMRanker") # type: ignore _base_doc = LGBMModel.fit.__doc__.replace("self : LGBMModel", "self : LGBMRanker") # type: ignore
fit.__doc__ = (_base_doc[:_base_doc.find('eval_class_weight :')] # type: ignore fit.__doc__ = (
+ _base_doc[_base_doc.find('eval_init_score :'):]) # type: ignore _base_doc[: _base_doc.find("eval_class_weight :")] # type: ignore
+ _base_doc[_base_doc.find("eval_init_score :") :]
) # type: ignore
_base_doc = fit.__doc__ _base_doc = fit.__doc__
_before_feature_name, _feature_name, _after_feature_name = _base_doc.partition('feature_name :') _before_feature_name, _feature_name, _after_feature_name = _base_doc.partition("feature_name :")
fit.__doc__ = f"""{_before_feature_name}eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5)) fit.__doc__ = f"""{_before_feature_name}eval_at : list or tuple of int, optional (default=(1, 2, 3, 4, 5))
The evaluation positions of the specified metric. The evaluation positions of the specified metric.
{_feature_name}{_after_feature_name}""" {_feature_name}{_after_feature_name}"""
...@@ -114,7 +114,6 @@ exclude = [ ...@@ -114,7 +114,6 @@ exclude = [
"compile/*.py", "compile/*.py",
"external_libs/*.py", "external_libs/*.py",
"lightgbm-python/*.py", "lightgbm-python/*.py",
"python-package/*.py",
] ]
indent-style = "space" indent-style = "space"
quote-style = "double" quote-style = "double"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment