# coding: utf-8
"""Compatibility library."""

"""pandas"""
try:
6
    from pandas import DataFrame as pd_DataFrame
7
8
    from pandas import Series as pd_Series
    from pandas import concat
9
10
11
12
    try:
        from pandas import CategoricalDtype as pd_CategoricalDtype
    except ImportError:
        from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
13
    PANDAS_INSTALLED = True
wxchan's avatar
wxchan committed
14
except ImportError:
15
16
    PANDAS_INSTALLED = False

17
    class pd_Series:  # type: ignore
18
19
        """Dummy class for pandas.Series."""

20
21
        def __init__(self, *args, **kwargs):
            pass
wxchan's avatar
wxchan committed
22

23
    class pd_DataFrame:  # type: ignore
24
25
        """Dummy class for pandas.DataFrame."""

26
27
        def __init__(self, *args, **kwargs):
            pass
wxchan's avatar
wxchan committed
28

29
    class pd_CategoricalDtype:  # type: ignore
30
31
32
33
34
        """Dummy class for pandas.CategoricalDtype."""

        def __init__(self, *args, **kwargs):
            pass

35
    concat = None
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""matplotlib"""
try:
    import matplotlib
    MATPLOTLIB_INSTALLED = True
except ImportError:
    MATPLOTLIB_INSTALLED = False

"""graphviz"""
try:
    import graphviz
    GRAPHVIZ_INSTALLED = True
except ImportError:
    GRAPHVIZ_INSTALLED = False

"""datatable"""
try:
    import datatable

    # Depending on the installed datatable version the frame class is exposed
    # as ``Frame`` or as ``DataTable``; alias whichever exists.
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable

    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args, **kwargs):
            pass
"""sklearn"""
try:
71
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
wxchan's avatar
wxchan committed
72
    from sklearn.preprocessing import LabelEncoder
73
    from sklearn.utils.class_weight import compute_sample_weight
74
    from sklearn.utils.multiclass import check_classification_targets
75
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
wxchan's avatar
wxchan committed
76
    try:
77
        from sklearn.exceptions import NotFittedError
78
        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
wxchan's avatar
wxchan committed
79
    except ImportError:
80
        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
81
        from sklearn.utils.validation import NotFittedError
82
83
84
85
86
87
88
89
90
91
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
        def _check_sample_weight(sample_weight, X, dtype=None):
            check_consistent_length(sample_weight, X)
            return sample_weight

wxchan's avatar
wxchan committed
92
    SKLEARN_INSTALLED = True
93
    _LGBMBaseCrossValidator = BaseCrossValidator
94
95
96
97
98
99
100
101
102
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
103
    _LGBMCheckSampleWeight = _check_sample_weight
104
    _LGBMAssertAllFinite = assert_all_finite
105
    _LGBMCheckClassificationTargets = check_classification_targets
106
    _LGBMComputeSampleWeight = compute_sample_weight
wxchan's avatar
wxchan committed
107
108
except ImportError:
    SKLEARN_INSTALLED = False
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

125
126
127
128
129
130
    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
131
    _LGBMCheckSampleWeight = None
132
    _LGBMAssertAllFinite = None
133
    _LGBMCheckClassificationTargets = None
134
    _LGBMComputeSampleWeight = None
135
136
137

"""dask"""
try:
138
139
    from dask import delayed
    from dask.array import Array as dask_Array
140
141
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
142
143
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
144
    from dask.distributed import Client, default_client, wait
145
146
147
    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False
148

149
150
    dask_array_from_delayed = None
    dask_bag_from_delayed = None
151
152
153
154
    delayed = None
    default_client = None
    wait = None

155
156
157
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

158
159
        def __init__(self, *args, **kwargs):
            pass
160

161
    class dask_Array:  # type: ignore
162
163
        """Dummy class for dask.array.Array."""

164
165
        def __init__(self, *args, **kwargs):
            pass
166

167
    class dask_DataFrame:  # type: ignore
168
169
        """Dummy class for dask.dataframe.DataFrame."""

170
171
        def __init__(self, *args, **kwargs):
            pass
172

173
    class dask_Series:  # type: ignore
174
        """Dummy class for dask.dataframe.Series."""
175

176
177
        def __init__(self, *args, **kwargs):
            pass
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195

"""cpu_count()"""
try:
    from joblib import cpu_count

    def _LGBMCpuCount(only_physical_cores: bool = True):
        return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
    try:
        from psutil import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True):
            return cpu_count(logical=not only_physical_cores)
    except ImportError:
        from multiprocessing import cpu_count

        def _LGBMCpuCount(only_physical_cores: bool = True):
            return cpu_count()