Unverified Commit ac3668fa authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] ignore some mypy errors in _DaskLGBMModel (#5774)

parent 38554095
...@@ -1035,7 +1035,7 @@ class _DaskLGBMModel: ...@@ -1035,7 +1035,7 @@ class _DaskLGBMModel:
def _lgb_dask_getstate(self) -> Dict[Any, Any]: def _lgb_dask_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization.""" """Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None) client = self.__dict__.pop("client", None)
self._other_params.pop("client", None) self._other_params.pop("client", None) # type: ignore[attr-defined]
out = deepcopy(self.__dict__) out = deepcopy(self.__dict__)
out.update({"client": None}) out.update({"client": None})
self.client = client self.client = client
...@@ -1064,7 +1064,7 @@ class _DaskLGBMModel: ...@@ -1064,7 +1064,7 @@ class _DaskLGBMModel:
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
params = self.get_params(True) params = self.get_params(True) # type: ignore[attr-defined]
params.pop("client", None) params.pop("client", None)
model = _train( model = _train(
...@@ -1087,13 +1087,13 @@ class _DaskLGBMModel: ...@@ -1087,13 +1087,13 @@ class _DaskLGBMModel:
**kwargs **kwargs
) )
self.set_params(**model.get_params()) self.set_params(**model.get_params()) # type: ignore[attr-defined]
self._lgb_dask_copy_extra_params(model, self) self._lgb_dask_copy_extra_params(model, self) # type: ignore[attr-defined]
return self return self
def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
params = self.get_params() params = self.get_params() # type: ignore[attr-defined]
params.pop("client", None) params.pop("client", None)
model = model_factory(**params) model = model_factory(**params)
self._lgb_dask_copy_extra_params(self, model) self._lgb_dask_copy_extra_params(self, model)
...@@ -1102,7 +1102,7 @@ class _DaskLGBMModel: ...@@ -1102,7 +1102,7 @@ class _DaskLGBMModel:
@staticmethod @staticmethod
def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None: def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
params = source.get_params() params = source.get_params() # type: ignore[union-attr]
attributes = source.__dict__ attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys()) extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names: for name in extra_param_names:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment