Unverified Commit b6c71e5e authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python][scikit-learn] change MRO (#3192)

* chanche MRO

* fix MRO resolution
parent eb7a1b7c
...@@ -94,9 +94,22 @@ try: ...@@ -94,9 +94,22 @@ try:
_LGBMComputeSampleWeight = compute_sample_weight _LGBMComputeSampleWeight = compute_sample_weight
except ImportError: except ImportError:
SKLEARN_INSTALLED = False SKLEARN_INSTALLED = False
_LGBMModelBase = object
_LGBMClassifierBase = object class _LGBMModelBase: # type: ignore
_LGBMRegressorBase = object """Dummy class for sklearn.base.BaseEstimator."""
pass
class _LGBMClassifierBase: # type: ignore
"""Dummy class for sklearn.base.ClassifierMixin."""
pass
class _LGBMRegressorBase: # type: ignore
"""Dummy class for sklearn.base.RegressorMixin."""
pass
_LGBMLabelEncoder = None _LGBMLabelEncoder = None
LGBMNotFittedError = ValueError LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None _LGBMStratifiedKFold = None
...@@ -118,11 +131,16 @@ try: ...@@ -118,11 +131,16 @@ try:
DASK_INSTALLED = True DASK_INSTALLED = True
except ImportError: except ImportError:
DASK_INSTALLED = False DASK_INSTALLED = False
delayed = None delayed = None
Client = object
default_client = None default_client = None
wait = None wait = None
class Client: # type: ignore
"""Dummy class for dask.distributed.Client."""
pass
class dask_Array: # type: ignore class dask_Array: # type: ignore
"""Dummy class for dask.array.Array.""" """Dummy class for dask.array.Array."""
......
...@@ -804,7 +804,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -804,7 +804,7 @@ class LGBMModel(_LGBMModelBase):
return self._Booster.feature_name() return self._Booster.feature_name()
class LGBMRegressor(LGBMModel, _LGBMRegressorBase): class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor.""" """LightGBM regressor."""
def fit(self, X, y, def fit(self, X, y,
...@@ -830,7 +830,7 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase): ...@@ -830,7 +830,7 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
+ _base_doc[_base_doc.find('eval_metric :'):]) + _base_doc[_base_doc.find('eval_metric :'):])
class LGBMClassifier(LGBMModel, _LGBMClassifierBase): class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"""LightGBM classifier.""" """LightGBM classifier."""
def fit(self, X, y, def fit(self, X, y,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment