Unverified Commit 646267d2 authored by James Lamb, committed by GitHub
Browse files

[dask] use more specific method names on _DaskLGBMModel (#4004)

parent 7f91dc66
......@@ -465,7 +465,7 @@ class _DaskLGBMModel:
return _get_dask_client(client=self.client)
def _lgb_getstate(self) -> Dict[Any, Any]:
def _lgb_dask_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self._other_params.pop("client", None)
......@@ -474,7 +474,7 @@ class _DaskLGBMModel:
self.client = client
return out
def _fit(
def _lgb_dask_fit(
self,
model_factory: Type[LGBMModel],
X: _DaskMatrixLike,
......@@ -501,20 +501,20 @@ class _DaskLGBMModel:
)
self.set_params(**model.get_params())
self._copy_extra_params(model, self)
self._lgb_dask_copy_extra_params(model, self)
return self
def _lgb_dask_to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
    """Create a local (non-Dask) copy of this model.

    Parameters
    ----------
    model_factory : type
        Class of the local equivalent of this model, e.g. ``LGBMClassifier``
        for ``DaskLGBMClassifier``.

    Returns
    -------
    model : LGBMModel
        Local underlying model with the same fitted state and parameters.
    """
    params = self.get_params()
    # 'client' is Dask-specific and not accepted by local model constructors.
    params.pop("client", None)
    model = model_factory(**params)
    self._lgb_dask_copy_extra_params(self, model)
    # Copying extra params may have carried 'client' over via _other_params;
    # strip it so the local model stays picklable and client-free.
    model._other_params.pop("client", None)
    return model
@staticmethod
def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
def _lgb_dask_copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
params = source.get_params()
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
......@@ -590,7 +590,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
def __getstate__(self) -> Dict[Any, Any]:
    """Return a picklable state dict, delegating to the Dask-specific helper.

    The helper removes un-picklable attributes (notably the Dask client)
    before serialization.
    """
    return self._lgb_dask_getstate()
def fit(
self,
......@@ -600,7 +600,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(
return self._lgb_dask_fit(
model_factory=LGBMClassifier,
X=X,
y=y,
......@@ -670,7 +670,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
model : lightgbm.LGBMClassifier
Local underlying model.
"""
return self._to_local(LGBMClassifier)
return self._lgb_dask_to_local(LGBMClassifier)
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
......@@ -741,7 +741,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
def __getstate__(self) -> Dict[Any, Any]:
    """Return a picklable state dict, delegating to the Dask-specific helper.

    The helper removes un-picklable attributes (notably the Dask client)
    before serialization.
    """
    return self._lgb_dask_getstate()
def fit(
self,
......@@ -751,7 +751,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(
return self._lgb_dask_fit(
model_factory=LGBMRegressor,
X=X,
y=y,
......@@ -802,7 +802,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
model : lightgbm.LGBMRegressor
Local underlying model.
"""
return self._to_local(LGBMRegressor)
return self._lgb_dask_to_local(LGBMRegressor)
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
......@@ -873,7 +873,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]
def __getstate__(self) -> Dict[Any, Any]:
    """Return a picklable state dict, delegating to the Dask-specific helper.

    The helper removes un-picklable attributes (notably the Dask client)
    before serialization.
    """
    return self._lgb_dask_getstate()
def fit(
self,
......@@ -888,7 +888,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
return self._fit(
return self._lgb_dask_fit(
model_factory=LGBMRanker,
X=X,
y=y,
......@@ -939,4 +939,4 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
model : lightgbm.LGBMRanker
Local underlying model.
"""
return self._to_local(LGBMRanker)
return self._lgb_dask_to_local(LGBMRanker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment