compat.py 6.77 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Compatibility library."""
3

4
5
from typing import List

wxchan's avatar
wxchan committed
"""pandas"""
# Optional dependency: pandas. When it is importable, DataFrame, Series,
# CategoricalDtype and concat are re-exported under pd_*/concat names;
# otherwise placeholder classes (and ``concat = None``) are defined so these
# names always exist in this module.
try:
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat

    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        # fall back to pandas.api.types for pandas versions that do not
        # export CategoricalDtype at the top level
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype

    PANDAS_INSTALLED = True
except ImportError:
    PANDAS_INSTALLED = False

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args, **kwargs):
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args, **kwargs):
            pass

    class pd_CategoricalDtype:  # type: ignore
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args, **kwargs):
            pass

    concat = None

"""matplotlib"""
# Availability probe only: the import exists to set the flag (hence the
# noqa for the otherwise-unused name).
try:
    import matplotlib  # noqa: F401

    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
# Availability probe only, same pattern as matplotlib above.
try:
    import graphviz  # noqa: F401

    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
# Optional dependency: datatable. dt_DataTable is bound to the library's
# frame class when available, or to a placeholder class otherwise.
try:
    import datatable

    # use ``Frame`` when the installed datatable provides it; otherwise fall
    # back to ``DataTable`` (presumably the name used by older releases)
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable

    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args, **kwargs):
            pass

"""sklearn"""
# Optional dependency: scikit-learn. When it is importable, its classes and
# helpers are re-exported under _LGBM*/LGBM* aliases. Otherwise the aliases
# become placeholders: dummy base classes, None for helpers, and ValueError
# for LGBMNotFittedError.
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
    except ImportError:
        # presumably for old scikit-learn releases that predate
        # sklearn.model_selection / sklearn.exceptions
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
        def _check_sample_weight(sample_weight, X, dtype=None):
            """Check only that sample_weight and X have consistent lengths.

            ``dtype`` is accepted for signature compatibility but ignored;
            ``sample_weight`` is returned unchanged.
            """
            check_consistent_length(sample_weight, X)
            return sample_weight

    SKLEARN_INSTALLED = True
    _LGBMBaseCrossValidator = BaseCrossValidator
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
    SKLEARN_INSTALLED = False

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    _LGBMBaseCrossValidator = None
    _LGBMLabelEncoder = None
    # ValueError stands in for sklearn's NotFittedError (presumably because
    # the sklearn exception subclasses ValueError)
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None

140

"""dask"""
try:
141
142
    from dask import delayed
    from dask.array import Array as dask_Array
143
144
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
145
146
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
147
    from dask.distributed import Client, Future, default_client, wait
148
149
150
    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False
151

152
153
    dask_array_from_delayed = None  # type: ignore[assignment]
    dask_bag_from_delayed = None  # type: ignore[assignment]
154
    delayed = None
155
156
    default_client = None  # type: ignore[assignment]
    wait = None  # type: ignore[assignment]
157

158
159
160
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

161
162
        def __init__(self, *args, **kwargs):
            pass
163

164
165
166
167
168
169
    class Future:  # type: ignore
        """Dummy class for dask.distributed.Future."""

        def __init__(self, *args, **kwargs):
            pass

170
    class dask_Array:  # type: ignore
171
172
        """Dummy class for dask.array.Array."""

173
174
        def __init__(self, *args, **kwargs):
            pass
175

176
    class dask_DataFrame:  # type: ignore
177
178
        """Dummy class for dask.dataframe.DataFrame."""

179
180
        def __init__(self, *args, **kwargs):
            pass
181

182
    class dask_Series:  # type: ignore
183
        """Dummy class for dask.dataframe.Series."""
184

185
186
        def __init__(self, *args, **kwargs):
            pass
187

188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""pyarrow"""
# Optional dependency: pyarrow (including its cffi bridge). If any of these
# imports fails, every name below falls back to a placeholder, so a partial
# import never leaves a mix of real and missing names.
try:
    from pyarrow import Table as pa_Table
    from pyarrow.cffi import ffi as arrow_cffi
    from pyarrow.types import is_floating as arrow_is_floating
    from pyarrow.types import is_integer as arrow_is_integer

    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_Table:  # type: ignore
        """Dummy class for pa.Table."""

        def __init__(self, *args, **kwargs):
            pass

    class arrow_cffi:  # type: ignore
        """Dummy class for pyarrow.cffi.ffi."""

        # same attribute names as the real ffi object, all set to None
        CData = None
        addressof = None
        cast = None
        new = None

        def __init__(self, *args, **kwargs):
            pass

    arrow_is_integer = None
    arrow_is_floating = None

"""cpu_count()"""
# _LGBMCpuCount is backed by the first importable provider: joblib, then
# psutil, then the standard library's multiprocessing module.
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
        """Return the CPU count reported by joblib.

        Parameters
        ----------
        only_physical_cores : bool, optional (default=True)
            Forwarded to joblib's ``cpu_count(only_physical_cores=...)``.
        """
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            """Return the CPU count reported by psutil, never less than 1.

            Parameters
            ----------
            only_physical_cores : bool, optional (default=True)
                If True, count physical cores (``logical=False``); otherwise
                count logical CPUs.
            """
            # psutil may return None when it cannot determine the count;
            # report a single CPU in that case
            return cpu_count(logical=not only_physical_cores) or 1

    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
            """Return the CPU count reported by multiprocessing, never less than 1.

            Parameters
            ----------
            only_physical_cores : bool, optional (default=True)
                Ignored: multiprocessing cannot distinguish physical cores
                from logical CPUs.
            """
            # multiprocessing.cpu_count() raises NotImplementedError when the
            # count is undeterminable; mirror the psutil branch and report a
            # single CPU instead of crashing.
            try:
                return cpu_count()
            except NotImplementedError:
                return 1
# Nothing is exported via ``from ... import *``; names must be imported
# explicitly.
__all__: List[str] = []