Unverified Commit b1e000c0 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[dask] remove unused private _client attribute (#3904)

* Update test_dask.py

* Update dask.py

* Update .vsts-ci.yml

* Revert "Update .vsts-ci.yml"

This reverts commit 98422be5b5095f0585de333b5b5545356776ef88.
parent 08c68c91
...@@ -468,11 +468,9 @@ class _DaskLGBMModel: ...@@ -468,11 +468,9 @@ class _DaskLGBMModel:
def _lgb_getstate(self) -> Dict[Any, Any]: def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization.""" """Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None) client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None) self._other_params.pop("client", None)
out = deepcopy(self.__dict__) out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None}) out.update({"client": None})
self._client = client
self.client = client self.client = client
return out return out
...@@ -521,7 +519,6 @@ class _DaskLGBMModel: ...@@ -521,7 +519,6 @@ class _DaskLGBMModel:
attributes = source.__dict__ attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys()) extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names: for name in extra_param_names:
if name != "_client":
setattr(dest, name, attributes[name]) setattr(dest, name, attributes[name])
...@@ -554,7 +551,6 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): ...@@ -554,7 +551,6 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
**kwargs: Any **kwargs: Any
): ):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__.""" """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client self.client = client
super().__init__( super().__init__(
boosting_type=boosting_type, boosting_type=boosting_type,
...@@ -672,7 +668,6 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): ...@@ -672,7 +668,6 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
**kwargs: Any **kwargs: Any
): ):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__.""" """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client self.client = client
super().__init__( super().__init__(
boosting_type=boosting_type, boosting_type=boosting_type,
...@@ -779,7 +774,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): ...@@ -779,7 +774,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
**kwargs: Any **kwargs: Any
): ):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__.""" """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client self.client = client
super().__init__( super().__init__(
boosting_type=boosting_type, boosting_type=boosting_type,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
"""Tests for lightgbm.dask module""" """Tests for lightgbm.dask module"""
import inspect import inspect
import joblib
import pickle import pickle
import socket import socket
from itertools import groupby from itertools import groupby
...@@ -19,6 +18,7 @@ if not lgb.compat.DASK_INSTALLED: ...@@ -19,6 +18,7 @@ if not lgb.compat.DASK_INSTALLED:
import cloudpickle import cloudpickle
import dask.array as da import dask.array as da
import dask.dataframe as dd import dask.dataframe as dd
import joblib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from scipy.stats import spearmanr from scipy.stats import spearmanr
...@@ -488,34 +488,29 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l ...@@ -488,34 +488,29 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
# should be able to use the class without specifying a client # should be able to use the class without specifying a client
dask_model = model_factory(**params) dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_ dask_model.client_
dask_model.fit(dX, dy, group=dg) dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_ assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None assert dask_model.client is None
assert dask_model.client_ == client assert dask_model.client_ == client
preds = dask_model.predict(dX) preds = dask_model.predict(dX)
assert isinstance(preds, da.Array) assert isinstance(preds, da.Array)
assert dask_model.fitted_ assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None assert dask_model.client is None
assert dask_model.client_ == client assert dask_model.client_ == client
local_model = dask_model.to_local() local_model = dask_model.to_local()
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
local_model._client
local_model.client local_model.client
local_model.client_ local_model.client_
# should be able to set client after construction # should be able to set client after construction
dask_model = model_factory(**params) dask_model = model_factory(**params)
dask_model.set_params(client=client) dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client assert dask_model.client == client
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
...@@ -523,21 +518,17 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l ...@@ -523,21 +518,17 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
dask_model.fit(dX, dy, group=dg) dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_ assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client assert dask_model.client == client
assert dask_model.client_ == client assert dask_model.client_ == client
preds = dask_model.predict(dX) preds = dask_model.predict(dX)
assert isinstance(preds, da.Array) assert isinstance(preds, da.Array)
assert dask_model.fitted_ assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client assert dask_model.client == client
assert dask_model.client_ == client assert dask_model.client_ == client
local_model = dask_model.to_local() local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
local_model._client
local_model.client local_model.client
local_model.client_ local_model.client_
...@@ -606,10 +597,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici ...@@ -606,10 +597,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
dask_model = model_factory(**params) dask_model = model_factory(**params)
local_model = dask_model.to_local() local_model = dask_model.to_local()
if set_client: if set_client:
assert dask_model._client == client1
assert dask_model.client == client1 assert dask_model.client == client1
else: else:
assert dask_model._client is None
assert dask_model.client is None assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
...@@ -640,14 +629,11 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici ...@@ -640,14 +629,11 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
serializer=serializer serializer=serializer
) )
assert model_from_disk._client is None
assert model_from_disk.client is None assert model_from_disk.client is None
if set_client: if set_client:
assert dask_model._client == client1
assert dask_model.client == client1 assert dask_model.client == client1
else: else:
assert dask_model._client is None
assert dask_model.client is None assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'): with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
...@@ -674,7 +660,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici ...@@ -674,7 +660,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert "client" not in local_model.get_params() assert "client" not in local_model.get_params()
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
local_model._client
local_model.client local_model.client
local_model.client_ local_model.client_
...@@ -701,17 +686,14 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici ...@@ -701,17 +686,14 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
) )
if set_client: if set_client:
assert dask_model._client == client1
assert dask_model.client == client1 assert dask_model.client == client1
assert dask_model.client_ == client1 assert dask_model.client_ == client1
else: else:
assert dask_model._client is None
assert dask_model.client is None assert dask_model.client is None
assert dask_model.client_ == default_client() assert dask_model.client_ == default_client()
assert dask_model.client_ == client2 assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory) assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client is None assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client() assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2 assert fitted_model_from_disk.client_ == client2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment