Unverified Commit 646267d2 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] use more specific method names on _DaskLGBMModel (#4004)

parent 7f91dc66
...@@ -465,7 +465,7 @@ class _DaskLGBMModel: ...@@ -465,7 +465,7 @@ class _DaskLGBMModel:
return _get_dask_client(client=self.client) return _get_dask_client(client=self.client)
def _lgb_getstate(self) -> Dict[Any, Any]: def _lgb_dask_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization.""" """Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None) client = self.__dict__.pop("client", None)
self._other_params.pop("client", None) self._other_params.pop("client", None)
...@@ -474,7 +474,7 @@ class _DaskLGBMModel: ...@@ -474,7 +474,7 @@ class _DaskLGBMModel:
self.client = client self.client = client
return out return out
def _fit( def _lgb_dask_fit(
self, self,
model_factory: Type[LGBMModel], model_factory: Type[LGBMModel],
X: _DaskMatrixLike, X: _DaskMatrixLike,
...@@ -501,20 +501,20 @@ class _DaskLGBMModel: ...@@ -501,20 +501,20 @@ class _DaskLGBMModel:
) )
self.set_params(**model.get_params()) self.set_params(**model.get_params())
self._copy_extra_params(model, self) self._lgb_dask_copy_extra_params(model, self)
return self return self
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
params = self.get_params() params = self.get_params()
params.pop("client", None) params.pop("client", None)
model = model_factory(**params) model = model_factory(**params)
self._copy_extra_params(self, model) self._lgb_dask_copy_extra_params(self, model)
model._other_params.pop("client", None) model._other_params.pop("client", None)
return model return model
@staticmethod @staticmethod
def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None: def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
params = source.get_params() params = source.get_params()
attributes = source.__dict__ attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys()) extra_param_names = set(attributes.keys()).difference(params.keys())
...@@ -590,7 +590,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -590,7 +590,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')] __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
def __getstate__(self) -> Dict[Any, Any]: def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate() return self._lgb_dask_getstate()
def fit( def fit(
self, self,
...@@ -600,7 +600,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -600,7 +600,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
**kwargs: Any **kwargs: Any
) -> "DaskLGBMClassifier": ) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit( return self._lgb_dask_fit(
model_factory=LGBMClassifier, model_factory=LGBMClassifier,
X=X, X=X,
y=y, y=y,
...@@ -670,7 +670,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -670,7 +670,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
model : lightgbm.LGBMClassifier model : lightgbm.LGBMClassifier
Local underlying model. Local underlying model.
""" """
return self._to_local(LGBMClassifier) return self._lgb_dask_to_local(LGBMClassifier)
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
...@@ -741,7 +741,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -741,7 +741,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')] __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
def __getstate__(self) -> Dict[Any, Any]: def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate() return self._lgb_dask_getstate()
def fit( def fit(
self, self,
...@@ -751,7 +751,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -751,7 +751,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
**kwargs: Any **kwargs: Any
) -> "DaskLGBMRegressor": ) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit( return self._lgb_dask_fit(
model_factory=LGBMRegressor, model_factory=LGBMRegressor,
X=X, X=X,
y=y, y=y,
...@@ -802,7 +802,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -802,7 +802,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
model : lightgbm.LGBMRegressor model : lightgbm.LGBMRegressor
Local underlying model. Local underlying model.
""" """
return self._to_local(LGBMRegressor) return self._lgb_dask_to_local(LGBMRegressor)
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
...@@ -873,7 +873,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -873,7 +873,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')] __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
def __getstate__(self) -> Dict[Any, Any]: def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_getstate() return self._lgb_dask_getstate()
def fit( def fit(
self, self,
...@@ -888,7 +888,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -888,7 +888,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
if init_score is not None: if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask') raise RuntimeError('init_score is not currently supported in lightgbm.dask')
return self._fit( return self._lgb_dask_fit(
model_factory=LGBMRanker, model_factory=LGBMRanker,
X=X, X=X,
y=y, y=y,
...@@ -939,4 +939,4 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -939,4 +939,4 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
model : lightgbm.LGBMRanker model : lightgbm.LGBMRanker
Local underlying model. Local underlying model.
""" """
return self._to_local(LGBMRanker) return self._lgb_dask_to_local(LGBMRanker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment