compat.py 7.58 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Compatibility library."""
3

4
5
from typing import List

wxchan's avatar
wxchan committed
6
7
"""pandas"""
try:
    from pandas import (
        DataFrame as pd_DataFrame,
        Series as pd_Series,
        concat,
    )

    # Try the top-level name first; if that import fails, take the class
    # from pandas.api.types instead.
    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
except ImportError:
    # pandas could not be imported: expose inert stand-ins under the same
    # names so that code referring to them still finds something bound.
    PANDAS_INSTALLED = False
    concat = None

    class pd_Series:  # type: ignore
        """Placeholder bound in place of pandas.Series."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

    class pd_DataFrame:  # type: ignore
        """Placeholder bound in place of pandas.DataFrame."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

    class pd_CategoricalDtype:  # type: ignore
        """Placeholder bound in place of pandas.CategoricalDtype."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""
else:
    # Every pandas import above succeeded.
    PANDAS_INSTALLED = True
38

39
40
41
42
43
44
45
46
47
48
"""numpy"""
try:
    from numpy.random import Generator as np_random_Generator
except ImportError:
    # The import failed: bind a do-nothing class under the same name.

    class np_random_Generator:  # type: ignore
        """Placeholder bound in place of np.random.Generator."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

49
50
"""matplotlib"""
try:
    import matplotlib  # noqa: F401
except ImportError:
    # Import failed: record that matplotlib is unavailable.
    MATPLOTLIB_INSTALLED = False
else:
    MATPLOTLIB_INSTALLED = True

"""graphviz"""
try:
    import graphviz  # noqa: F401
except ImportError:
    # Import failed: record that graphviz is unavailable.
    GRAPHVIZ_INSTALLED = False
else:
    GRAPHVIZ_INSTALLED = True

63
64
"""datatable"""
try:
    import datatable

    # Use datatable.Frame when the module exposes it, otherwise fall back to
    # the datatable.DataTable attribute.
    dt_DataTable = datatable.Frame if hasattr(datatable, "Frame") else datatable.DataTable
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Placeholder bound in place of datatable.DataTable."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""
else:
    DATATABLE_INSTALLED = True
79
80


wxchan's avatar
wxchan committed
81
82
"""sklearn"""
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y

    # First choice of module paths; if either import fails, retry with the
    # alternative locations below.
    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError

    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        def _check_sample_weight(sample_weight, X, dtype=None):
            """Stand-in for scikit-learn versions lacking ``_check_sample_weight``.

            Only verifies that ``sample_weight`` and ``X`` have consistent
            lengths and hands ``sample_weight`` back untouched; ``dtype`` is
            accepted for signature compatibility and ignored.
            """
            check_consistent_length(sample_weight, X)
            return sample_weight
except ImportError:
    SKLEARN_INSTALLED = False

    # Empty base classes so that class statements inheriting from these
    # names can still be executed.
    class _LGBMModelBase:  # type: ignore
        """Placeholder bound in place of sklearn.base.BaseEstimator."""

    class _LGBMClassifierBase:  # type: ignore
        """Placeholder bound in place of sklearn.base.ClassifierMixin."""

    class _LGBMRegressorBase:  # type: ignore
        """Placeholder bound in place of sklearn.base.RegressorMixin."""

    # A built-in exception takes the place of NotFittedError.
    LGBMNotFittedError = ValueError

    # Everything else has no stand-in and is bound to None.
    _LGBMBaseCrossValidator = None
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMLabelEncoder = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None
else:
    SKLEARN_INSTALLED = True

    # Estimator base classes and exception.
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    LGBMNotFittedError = NotFittedError

    # Cross-validation splitters.
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold

    # Preprocessing and validation helpers.
    _LGBMLabelEncoder = LabelEncoder
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
148
149
150

"""dask"""
try:
151
152
    from dask import delayed
    from dask.array import Array as dask_Array
153
154
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
155
156
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
157
    from dask.distributed import Client, Future, default_client, wait
158
159
160
    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False
161

162
163
    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
164
    delayed = None
165
166
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]
167

168
169
170
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

171
172
        def __init__(self, *args, **kwargs):
            pass
173

174
175
176
177
178
179
    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

        def __init__(self, *args, **kwargs):
            pass

180
    class dask_Array:  # type: ignore
181
182
        """Dummy class for dask.array.Array."""

183
184
        def __init__(self, *args, **kwargs):
            pass
185

186
    class dask_DataFrame:  # type: ignore
187
188
        """Dummy class for dask.dataframe.DataFrame."""

189
190
        def __init__(self, *args, **kwargs):
            pass
191

192
    class dask_Series:  # type: ignore
193
        """Dummy class for dask.dataframe.Series."""
194

195
196
        def __init__(self, *args, **kwargs):
            pass
197

198
199
"""pyarrow"""
try:
    import pyarrow.compute as pa_compute
    from pyarrow import (
        Array as pa_Array,
        ChunkedArray as pa_ChunkedArray,
        Table as pa_Table,
    )
    from pyarrow.cffi import ffi as arrow_cffi
    from pyarrow.types import (
        is_floating as arrow_is_floating,
        is_integer as arrow_is_integer,
    )
except ImportError:
    PYARROW_INSTALLED = False

    # The type-predicate helpers have no stand-in and are bound to None.
    arrow_is_integer = None
    arrow_is_floating = None

    class pa_Array:  # type: ignore
        """Placeholder bound in place of pa.Array."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

    class pa_ChunkedArray:  # type: ignore
        """Placeholder bound in place of pa.ChunkedArray."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

    class pa_Table:  # type: ignore
        """Placeholder bound in place of pa.Table."""

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

    class arrow_cffi:  # type: ignore
        """Placeholder bound in place of pyarrow.cffi.ffi."""

        # Attribute names exist so lookups succeed; each is simply None.
        CData = addressof = cast = new = None

        def __init__(self, *args, **kwargs):
            """Accept and discard any arguments."""

    class pa_compute:  # type: ignore
        """Placeholder bound in place of the pyarrow.compute module."""

        # Attribute names exist so lookups succeed; each is simply None.
        all = equal = None
else:
    PYARROW_INSTALLED = True

249
250
251
252
"""cpu_count()"""
# Pick a CPU-counting backend in order of preference: joblib, then psutil,
# then the standard library. Whichever import succeeds defines _LGBMCpuCount.
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        """Return the CPU count reported by joblib.

        Parameters
        ----------
        only_physical_cores : bool, optional (default=True)
            Forwarded unchanged to ``joblib.cpu_count``.

        Returns
        -------
        int
            Whatever ``joblib.cpu_count`` returns.
        """
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            """Return the CPU count reported by psutil, or 1 if it has none.

            Parameters
            ----------
            only_physical_cores : bool, optional (default=True)
                Translated to psutil's ``logical`` flag (its negation).

            Returns
            -------
            int
                The psutil result, replaced by 1 when that result is falsy.
            """
            return cpu_count(logical=not only_physical_cores) or 1
    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            """Return the CPU count reported by multiprocessing, or 1 if unknown.

            Parameters
            ----------
            only_physical_cores : bool, optional (default=True)
                Accepted for signature compatibility; this backend is called
                without arguments, so the flag has no effect here.

            Returns
            -------
            int
                The ``multiprocessing.cpu_count()`` result, or 1 when that
                call raises ``NotImplementedError``.
            """
            try:
                return cpu_count()
            except NotImplementedError:
                # Mirror the psutil variant: an undeterminable count degrades
                # to a single CPU instead of propagating an exception.
                return 1
266

267
__all__: List[str] = []