# coding: utf-8
"""Compatibility library."""

"""pandas"""
try:
    # Bind the pandas names used by the rest of the package under local aliases.
    from pandas import DataFrame as pd_DataFrame
    from pandas import Series as pd_Series
    from pandas import concat
    # NOTE(review): ``is_sparse`` may be deprecated in newer pandas releases;
    # if it is ever removed, this whole section falls back to the
    # "pandas not installed" branch below -- confirm before upgrading.
    from pandas.api.types import is_sparse as is_dtype_sparse
    PANDAS_INSTALLED = True
except ImportError:
    # Any failed import above lands here: define placeholders so that every
    # name exported by this section still exists.
    PANDAS_INSTALLED = False

    class pd_Series:  # type: ignore
        """Dummy class for pandas.Series."""

        def __init__(self, *args, **kwargs):
            # Accept and ignore any arguments.
            pass

    class pd_DataFrame:  # type: ignore
        """Dummy class for pandas.DataFrame."""

        def __init__(self, *args, **kwargs):
            # Accept and ignore any arguments.
            pass

    # Function-valued names have no usable stand-in; they are set to None.
    concat = None
    is_dtype_sparse = None

"""matplotlib"""
try:
    import matplotlib
except ImportError:
    # The import failed: record that matplotlib is unavailable.
    MATPLOTLIB_INSTALLED = False
else:
    # The import succeeded: record that matplotlib is available.
    MATPLOTLIB_INSTALLED = True

"""graphviz"""
try:
    import graphviz
except ImportError:
    # The import failed: record that graphviz is unavailable.
    GRAPHVIZ_INSTALLED = False
else:
    # The import succeeded: record that graphviz is available.
    GRAPHVIZ_INSTALLED = True

"""datatable"""
try:
    import datatable
    # Use ``Frame`` when the imported datatable module exposes it,
    # otherwise fall back to ``DataTable``.
    if hasattr(datatable, "Frame"):
        dt_DataTable = datatable.Frame
    else:
        dt_DataTable = datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    # datatable could not be imported: define a placeholder class so that
    # ``dt_DataTable`` still exists.
    DATATABLE_INSTALLED = False

    class dt_DataTable:  # type: ignore
        """Dummy class for datatable.DataTable."""

        def __init__(self, *args, **kwargs):
            # Accept and ignore any arguments.
            pass


"""sklearn"""
try:
    from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
    try:
        from sklearn.exceptions import NotFittedError
        from sklearn.model_selection import GroupKFold, StratifiedKFold
    except ImportError:
        # Fall back to the alternative module locations when the imports
        # above are not available.
        from sklearn.cross_validation import GroupKFold, StratifiedKFold
        from sklearn.utils.validation import NotFittedError
    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        # dummy function to support older version of scikit-learn
        def _check_sample_weight(sample_weight, X, dtype=None):
            # Only verifies that ``sample_weight`` and ``X`` have consistent
            # lengths; ``dtype`` is accepted but not used, and
            # ``sample_weight`` is returned unchanged.
            check_consistent_length(sample_weight, X)
            return sample_weight

    SKLEARN_INSTALLED = True
    # Package-internal aliases for the scikit-learn objects imported above.
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    _LGBMLabelEncoder = LabelEncoder
    LGBMNotFittedError = NotFittedError
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
except ImportError:
    # scikit-learn could not be imported: define empty base classes and
    # set the remaining aliases to None (the error alias becomes ValueError).
    SKLEARN_INSTALLED = False

    class _LGBMModelBase:  # type: ignore
        """Dummy class for sklearn.base.BaseEstimator."""

        pass

    class _LGBMClassifierBase:  # type: ignore
        """Dummy class for sklearn.base.ClassifierMixin."""

        pass

    class _LGBMRegressorBase:  # type: ignore
        """Dummy class for sklearn.base.RegressorMixin."""

        pass

    _LGBMLabelEncoder = None
    LGBMNotFittedError = ValueError
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None

"""dask"""
try:
129
130
    from dask import delayed
    from dask.array import Array as dask_Array
131
132
    from dask.array import from_delayed as dask_array_from_delayed
    from dask.bag import from_delayed as dask_bag_from_delayed
133
134
    from dask.dataframe import DataFrame as dask_DataFrame
    from dask.dataframe import Series as dask_Series
135
    from dask.distributed import Client, default_client, wait
136
137
138
    DASK_INSTALLED = True
except ImportError:
    DASK_INSTALLED = False
139

140
141
    dask_array_from_delayed = None
    dask_bag_from_delayed = None
142
143
144
145
    delayed = None
    default_client = None
    wait = None

146
147
148
    class Client:  # type: ignore
        """Dummy class for dask.distributed.Client."""

149
150
        def __init__(self, *args, **kwargs):
            pass
151

152
    class dask_Array:  # type: ignore
153
154
        """Dummy class for dask.array.Array."""

155
156
        def __init__(self, *args, **kwargs):
            pass
157

158
    class dask_DataFrame:  # type: ignore
159
160
        """Dummy class for dask.dataframe.DataFrame."""

161
162
        def __init__(self, *args, **kwargs):
            pass
163

164
    class dask_Series:  # type: ignore
165
        """Dummy class for dask.dataframe.Series."""
166

167
168
        def __init__(self, *args, **kwargs):
            pass