# coding: utf-8
"""Compatibility library."""

from typing import Any, List

# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509
"""sklearn"""
try:
10
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
wxchan's avatar
wxchan committed
11
    from sklearn.preprocessing import LabelEncoder
12
    from sklearn.utils.class_weight import compute_sample_weight
13
    from sklearn.utils.multiclass import check_classification_targets
14
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
15

wxchan's avatar
wxchan committed
16
    try:
17
        from sklearn.exceptions import NotFittedError
18
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
wxchan's avatar
wxchan committed
19
    except ImportError:
20
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
21
        from sklearn.utils.validation import NotFittedError
22
23
24
25
26
27
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
28
        def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
29
30
31
            check_consistent_length(sample_weight, X)
            return sample_weight

wxchan's avatar
wxchan committed
32
    SKLEARN_INSTALLED = True
33
    _LGBMBaseCrossValidator = BaseCrossValidator
34
35
36
37
38
39
40
41
42
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
43
    _LGBMCheckSampleWeight = _check_sample_weight
44
    _LGBMAssertAllFinite = assert_all_finite
45
    _LGBMCheckClassificationTargets = check_classification_targets
46
    _LGBMComputeSampleWeight = compute_sample_weight
wxchan's avatar
wxchan committed
47
48
except ImportError:
    SKLEARN_INSTALLED = False
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

65
    _LGBMBaseCrossValidator = None
66
67
68
69
70
71
    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
72
    _LGBMCheckSampleWeight = None
73
    _LGBMAssertAllFinite = None
74
    _LGBMCheckClassificationTargets = None
75
    _LGBMComputeSampleWeight = None
76

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""pandas"""
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat

    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    concat = None

"""matplotlib"""
try:
    import matplotlib  # noqa: F401

    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz  # noqa: F401

    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    import datatable

    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass


146
147
"""dask"""
try:
148
149
    from dask import delayed
    from dask.array import Array as dask_Array
150
151
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
152
153
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
154
    from dask.distributed import Client, Future, default_client, wait
155

156
    DASK_INSTALLED = True
157
158
159
160
161
162
163
164
165
166
167
# catching 'ValueError' here because of this:
# https://github.com/microsoft/LightGBM/issues/6365#issuecomment-2002330003
#
# That's potentially risky as dask does some significant import-time processing,
# like loading configuration from environment variables and files, and catching
# ValueError here might hide issues with that config-loading.
#
# But in exchange, it's less likely that 'import lightgbm' will fail for
# dask-related reasons, which is beneficial for any workloads that are using
# lightgbm but not its Dask functionality.
except (ImportError, ValueError):
168
    DASK_INSTALLED = False
169

170
171
    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
172
    delayed = None
173
174
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]
175

176
177
178
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

179
        def __init__(self, *args: Any, **kwargs: Any):
180
            pass
181

182
183
184
    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

185
        def __init__(self, *args: Any, **kwargs: Any):
186
187
            pass

188
    class dask_Array:  # type: ignore
189
190
        """Dummy class for dask.array.Array."""

191
        def __init__(self, *args: Any, **kwargs: Any):
192
            pass
193

194
    class dask_DataFrame:  # type: ignore
195
196
        """Dummy class for dask.dataframe.DataFrame."""

197
        def __init__(self, *args: Any, **kwargs: Any):
198
            pass
199

200
    class dask_Series:  # type: ignore
201
        """Dummy class for dask.dataframe.Series."""
202

203
        def __init__(self, *args: Any, **kwargs: Any):
204
            pass
205

206

207
208
"""pyarrow"""
try:
209
    import pyarrow.compute as pa_compute
210
211
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
212
    from pyarrow import Table as pa_Table
213
    from pyarrow import chunked_array as pa_chunked_array
214
    from pyarrow.cffi import ffi as arrow_cffi
215
    from pyarrow.types import is_boolean as arrow_is_boolean
216
217
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer
218

219
220
221
222
    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

223
224
225
    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

226
        def __init__(self, *args: Any, **kwargs: Any):
227
228
229
230
231
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

232
        def __init__(self, *args: Any, **kwargs: Any):
233
234
            pass

235
236
237
    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

238
        def __init__(self, *args: Any, **kwargs: Any):
239
240
241
242
243
244
245
246
247
248
            pass

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        CData = None
        addressof = None
        cast = None
        new = None

249
        def __init__(self, *args: Any, **kwargs: Any):
250
251
            pass

252
253
254
255
256
257
    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute."""

        all = None
        equal = None

258
    pa_chunked_array = None
259
    arrow_is_boolean = None
260
261
262
    arrow_is_integer = None
    arrow_is_floating = None

263
264
265
266
"""cpu_count()"""
try:
    from joblib import cpu_count

267
    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
268
269
270
271
272
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

273
274
        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            return cpu_count(logical=not only_physical_cores) or 1
275
276
277
    except ImportError:
        from multiprocessing import cpu_count

278
        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
279
            return cpu_count()
280

281

282
# Explicitly empty public API: nothing from this compatibility module is
# re-exported via wildcard imports.
__all__: List[str] = []