# coding: utf-8
"""Compatibility library."""

from typing import TYPE_CHECKING, Any, List

# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509
wxchan's avatar
wxchan committed
8
9
"""sklearn"""
try:
10
    from sklearn import __version__ as _sklearn_version
11
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
wxchan's avatar
wxchan committed
12
    from sklearn.preprocessing import LabelEncoder
13
    from sklearn.utils.class_weight import compute_sample_weight
14
    from sklearn.utils.multiclass import check_classification_targets
15
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
16

17
18
19
20
21
22
23
24
    # sklearn.utils Tags types can be imported unconditionally once
    # lightgbm's minimum scikit-learn version is 1.6 or higher
    try:
        from sklearn.utils import ClassifierTags as _sklearn_ClassifierTags
        from sklearn.utils import RegressorTags as _sklearn_RegressorTags
    except ImportError:
        _sklearn_ClassifierTags = None
        _sklearn_RegressorTags = None
wxchan's avatar
wxchan committed
25
    try:
26
        from sklearn.exceptions import NotFittedError
27
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
wxchan's avatar
wxchan committed
28
    except ImportError:
29
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
30
        from sklearn.utils.validation import NotFittedError
31
32
33
34
35
36
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
37
        def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
38
39
40
            check_consistent_length(sample_weight, X)
            return sample_weight

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    try:
        from sklearn.utils.validation import validate_data
    except ImportError:
        # validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
        # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
        def validate_data(
            _estimator,
            X,
            y="no_validation",
            accept_sparse: bool = True,
            # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
            ensure_all_finite: bool = False,
            ensure_min_samples: int = 1,
            # trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
            **ignored_kwargs,
        ):
            # it's safe to import _num_features unconditionally because:
            #
            #  * it was first added in scikit-learn 0.24.2
            #  * lightgbm cannot be used with scikit-learn versions older than that
            #  * this validate_data() re-implementation will not be called in scikit-learn>=1.6
            #
            from sklearn.utils.validation import _num_features

            # _num_features() raises a TypeError on 1-dimensional input. That's a problem
            # because scikit-learn's 'check_fit1d' estimator check sets that expectation that
            # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
            #
            # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
            if hasattr(X, "shape") and len(X.shape) == 1:
                n_features_in_ = 1
            else:
                n_features_in_ = _num_features(X)

            no_val_y = isinstance(y, str) and y == "no_validation"

            # NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
            if no_val_y:
                X = check_array(
                    X,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )
            else:
                X, y = check_X_y(
                    X,
                    y,
                    accept_sparse=accept_sparse,
                    force_all_finite=ensure_all_finite,
                    ensure_min_samples=ensure_min_samples,
                )

                # this only needs to be updated at fit() time
                _estimator.n_features_in_ = n_features_in_

            # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
            if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
                raise ValueError(
                    f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
                    f"is expecting {_estimator._n_features} features as input."
                )

            if no_val_y:
                return X
            else:
                return X, y

wxchan's avatar
wxchan committed
109
    SKLEARN_INSTALLED = True
110
    _LGBMBaseCrossValidator = BaseCrossValidator
111
112
113
114
115
116
117
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
118
    _LGBMCheckSampleWeight = _check_sample_weight
119
    _LGBMAssertAllFinite = assert_all_finite
120
    _LGBMCheckClassificationTargets = check_classification_targets
121
    _LGBMComputeSampleWeight = compute_sample_weight
122
    _LGBMValidateData = validate_data
wxchan's avatar
wxchan committed
123
124
except ImportError:
    SKLEARN_INSTALLED = False
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

141
    _LGBMBaseCrossValidator = None
142
143
144
145
    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
146
    _LGBMCheckSampleWeight = None
147
    _LGBMAssertAllFinite = None
148
    _LGBMCheckClassificationTargets = None
149
    _LGBMComputeSampleWeight = None
150
    _LGBMValidateData = None
151
152
    _sklearn_ClassifierTags = None
    _sklearn_RegressorTags = None
153
154
155
156
157
158
159
160
161
162
163
    _sklearn_version = None

# additional scikit-learn imports only for type hints
if TYPE_CHECKING:
    # This block never executes at runtime (TYPE_CHECKING is False); it only
    # makes '_sklearn_Tags' resolvable for static type checkers.
    # sklearn.utils.Tags can be imported unconditionally once
    # lightgbm's minimum scikit-learn version is 1.6 or higher
    try:
        from sklearn.utils import Tags as _sklearn_Tags
    except ImportError:
        _sklearn_Tags = None

"""pandas"""
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat

    # 'CategoricalDtype' lives at the pandas top level in newer releases;
    # older versions only exposed it under 'pandas.api.types'
    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    # placeholder types so isinstance() checks elsewhere are safe without pandas

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    concat = None

"""matplotlib"""
try:
    import matplotlib  # noqa: F401
except ImportError:
    MATPLOTLIB_INSTALLED = False
else:
    # only reached when the import above succeeded
    MATPLOTLIB_INSTALLED = True

"""graphviz"""
try:
    import graphviz  # noqa: F401
except ImportError:
    GRAPHVIZ_INSTALLED = False
else:
    # only reached when the import above succeeded
    GRAPHVIZ_INSTALLED = True

"""datatable"""
try:
    import datatable

    # newer datatable exposes the frame type as 'Frame'; very old releases
    # named it 'DataTable'
    dt_DataTable = datatable.Frame if hasattr(datatable, "Frame") else datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass


"""dask"""
try:
236
237
    from dask import delayed
    from dask.array import Array as dask_Array
238
239
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
240
241
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
242
    from dask.distributed import Client, Future, default_client, wait
243

244
    DASK_INSTALLED = True
245
246
247
248
249
250
251
252
253
254
255
# catching 'ValueError' here because of this:
# https://github.com/microsoft/LightGBM/issues/6365#issuecomment-2002330003
#
# That's potentially risky as dask does some significant import-time processing,
# like loading configuration from environment variables and files, and catching
# ValueError here might hide issues with that config-loading.
#
# But in exchange, it's less likely that 'import lightgbm' will fail for
# dask-related reasons, which is beneficial for any workloads that are using
# lightgbm but not its Dask functionality.
except (ImportError, ValueError):
256
    DASK_INSTALLED = False
257

258
259
    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
260
    delayed = None
261
262
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]
263

264
265
266
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

267
        def __init__(self, *args: Any, **kwargs: Any):
268
            pass
269

270
271
272
    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

273
        def __init__(self, *args: Any, **kwargs: Any):
274
275
            pass

276
    class dask_Array:  # type: ignore
277
278
        """Dummy class for dask.array.Array."""

279
        def __init__(self, *args: Any, **kwargs: Any):
280
            pass
281

282
    class dask_DataFrame:  # type: ignore
283
284
        """Dummy class for dask.dataframe.DataFrame."""

285
        def __init__(self, *args: Any, **kwargs: Any):
286
            pass
287

288
    class dask_Series:  # type: ignore
289
        """Dummy class for dask.dataframe.Series."""
290

291
        def __init__(self, *args: Any, **kwargs: Any):
292
            pass
293

294

295
296
"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
    from pyarrow import Table as pa_Table
    from pyarrow import chunked_array as pa_chunked_array
    from pyarrow.cffi import ffi as arrow_cffi
    from pyarrow.types import is_boolean as arrow_is_boolean
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer

    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    # placeholder types/functions so attribute access and isinstance() checks
    # are safe without pyarrow

    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        # attributes mirror the pieces of the real cffi object that lightgbm touches
        CData = None
        addressof = None
        cast = None
        new = None

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute."""

        # attributes mirror the pyarrow.compute functions lightgbm uses
        all = None
        equal = None

    pa_chunked_array = None
    arrow_is_boolean = None
    arrow_is_integer = None
    arrow_is_floating = None

"""cpu_count()"""
try:
    from joblib import cpu_count

355
    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
356
357
358
359
360
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

361
362
        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count(logical=not only_physical_cores) or 1
363
364
365
    except ImportError:
        from multiprocessing import cpu_count

366
        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
367
            return cpu_count()
# Intentionally empty: this compatibility module is internal, so no names are
# re-exported via wildcard imports.
__all__: List[str] = []