# coding: utf-8
"""Compatibility library."""

from typing import Any, List

"""pandas"""
try:
    from pandas import concat
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series

    # ``CategoricalDtype`` is tried in two locations: the top-level pandas
    # namespace first, then ``pandas.api.types`` for versions where it only
    # lives there.
    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype

    PANDAS_INSTALLED = True
except ImportError:
    # pandas is unavailable: expose placeholder classes so that the names
    # always exist at module level.
    PANDAS_INSTALLED = False

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    concat = None
"""numpy"""
try:
    from numpy.random import Generator as np_random_Generator
except ImportError:
    # No new-style ``numpy.random.Generator`` available: provide a placeholder
    # so the name is always importable from this module.
    class np_random_Generator:  # type: ignore
        """Dummy class for np.random.Generator."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass
"""matplotlib"""
# Optimistically assume matplotlib is present; downgrade the flag if the
# import fails. Only the availability flag is needed, not the module itself.
MATPLOTLIB_INSTALLED = True
try:
    import matplotlib  # noqa: F401
except ImportError:
    MATPLOTLIB_INSTALLED = False
"""graphviz"""
# Same pattern as the matplotlib probe above: only record availability.
GRAPHVIZ_INSTALLED = True
try:
    import graphviz  # noqa: F401
except ImportError:
    GRAPHVIZ_INSTALLED = False
"""datatable"""
try:
    import datatable

    # Prefer ``datatable.Frame``; fall back to the ``DataTable`` name for
    # versions of the package that do not expose ``Frame``.
    dt_DataTable = datatable.Frame if hasattr(datatable, "Frame") else datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass
"""sklearn"""
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y

    # Newer scikit-learn keeps these in ``model_selection``/``exceptions``;
    # very old releases exposed them under ``cross_validation``/``validation``.
    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError

    # ``_check_sample_weight`` only exists in newer scikit-learn; emulate the
    # minimal length check for older releases.
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
        def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
            check_consistent_length(sample_weight, X)
            return sample_weight

    SKLEARN_INSTALLED = True
    # Aliases used throughout the scikit-learn API wrapper.
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
    SKLEARN_INSTALLED = False

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    # Without scikit-learn, a plain ValueError stands in for NotFittedError;
    # every other helper is simply absent.
    LGBMNotFittedError = ValueError
    _LGBMLabelEncoder = None
    _LGBMBaseCrossValidator = None
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None
"""dask"""
try:
158
159
    from dask import delayed
    from dask.array import Array as dask_Array
160
161
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
162
163
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
164
    from dask.distributed import Client, Future, default_client, wait
165

166
    DASK_INSTALLED = True
167
168
169
170
171
172
173
174
175
176
177
# catching 'ValueError' here because of this:
# https://github.com/microsoft/LightGBM/issues/6365#issuecomment-2002330003
#
# That's potentially risky as dask does some significant import-time processing,
# like loading configuration from environment variables and files, and catching
# ValueError here might hide issues with that config-loading.
#
# But in exchange, it's less likely that 'import lightgbm' will fail for
# dask-related reasons, which is beneficial for any workloads that are using
# lightgbm but not its Dask functionality.
except (ImportError, ValueError):
178
    DASK_INSTALLED = False
179

180
181
    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
182
    delayed = None
183
184
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]
185

186
187
188
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

189
        def __init__(self, *args: Any, **kwargs: Any):
190
            pass
191

192
193
194
    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

195
        def __init__(self, *args: Any, **kwargs: Any):
196
197
            pass

198
    class dask_Array:  # type: ignore
199
200
        """Dummy class for dask.array.Array."""

201
        def __init__(self, *args: Any, **kwargs: Any):
202
            pass
203

204
    class dask_DataFrame:  # type: ignore
205
206
        """Dummy class for dask.dataframe.DataFrame."""

207
        def __init__(self, *args: Any, **kwargs: Any):
208
            pass
209

210
    class dask_Series:  # type: ignore
211
        """Dummy class for dask.dataframe.Series."""
212

213
        def __init__(self, *args: Any, **kwargs: Any):
214
            pass
215

216

217
218
"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import Array as pa_Array
    from pyarrow import ChunkedArray as pa_ChunkedArray
    from pyarrow import chunked_array as pa_chunked_array
    from pyarrow import Table as pa_Table
    from pyarrow.cffi import ffi as arrow_cffi
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer

    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_Array:  # type: ignore
        """Dummy class for pa.Array."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_ChunkedArray:  # type: ignore
        """Dummy class for pa.ChunkedArray."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        # Attributes mirror the handful of ffi members used elsewhere.
        CData = None
        addressof = None
        cast = None
        new = None

        def __init__(self, *args: Any, **kwargs: Any):
            pass

    class pa_compute:  # type: ignore
        """Dummy class for pyarrow.compute."""

        all = None
        equal = None

    # Function fallbacks when pyarrow is unavailable.
    pa_chunked_array = None
    arrow_is_floating = None
    arrow_is_integer = None
"""cpu_count()"""
# Three-way fallback: joblib (knows physical cores) -> psutil (logical /
# physical switch) -> multiprocessing (logical count only).
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        """Return the CPU count reported by joblib."""
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            """Return the CPU count reported by psutil, or 1 if undetermined."""
            return cpu_count(logical=not only_physical_cores) or 1
    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            """Return the logical CPU count; ``only_physical_cores`` is ignored in this fallback."""
            return cpu_count()
# Intentionally empty: ``from lightgbm.compat import *`` exposes nothing;
# the names defined above are meant to be imported explicitly.
__all__: List[str] = []