Unverified Commit b1e000c0 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[dask] remove unused private _client attribute (#3904)

* Update test_dask.py

* Update dask.py

* Update .vsts-ci.yml

* Revert "Update .vsts-ci.yml"

This reverts commit 98422be5b5095f0585de333b5b5545356776ef88.
parent 08c68c91
......@@ -468,11 +468,9 @@ class _DaskLGBMModel:
def _lgb_getstate(self) -> Dict[Any, Any]:
"""Remove un-picklable attributes before serialization."""
client = self.__dict__.pop("client", None)
self.__dict__.pop("_client", None)
self._other_params.pop("client", None)
out = deepcopy(self.__dict__)
out.update({"_client": None, "client": None})
self._client = client
out.update({"client": None})
self.client = client
return out
......@@ -521,7 +519,6 @@ class _DaskLGBMModel:
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
for name in extra_param_names:
if name != "_client":
setattr(dest, name, attributes[name])
......@@ -554,7 +551,6 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
......@@ -672,7 +668,6 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
......@@ -779,7 +774,6 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
self._client = client
self.client = client
super().__init__(
boosting_type=boosting_type,
......
......@@ -2,7 +2,6 @@
"""Tests for lightgbm.dask module"""
import inspect
import joblib
import pickle
import socket
from itertools import groupby
......@@ -19,6 +18,7 @@ if not lgb.compat.DASK_INSTALLED:
import cloudpickle
import dask.array as da
import dask.dataframe as dd
import joblib
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
......@@ -488,34 +488,29 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
# should be able to use the class without specifying a client
dask_model = model_factory(**params)
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
dask_model.client_
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == client
local_model = dask_model.to_local()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
# should be able to set client after construction
dask_model = model_factory(**params)
dask_model.set_params(client=client)
assert dask_model._client == client
assert dask_model.client == client
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
......@@ -523,21 +518,17 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, l
dask_model.fit(dX, dy, group=dg)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
preds = dask_model.predict(dX)
assert isinstance(preds, da.Array)
assert dask_model.fitted_
assert dask_model._client == client
assert dask_model.client == client
assert dask_model.client_ == client
local_model = dask_model.to_local()
assert getattr(local_model, "_client", None) is None
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
......@@ -606,10 +597,8 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
dask_model = model_factory(**params)
local_model = dask_model.to_local()
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
......@@ -640,14 +629,11 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
serializer=serializer
)
assert model_from_disk._client is None
assert model_from_disk.client is None
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
else:
assert dask_model._client is None
assert dask_model.client is None
with pytest.raises(lgb.compat.LGBMNotFittedError, match='Cannot access property client_ before calling fit'):
......@@ -674,7 +660,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert "client" not in local_model.get_params()
with pytest.raises(AttributeError):
local_model._client
local_model.client
local_model.client_
......@@ -701,17 +686,14 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
)
if set_client:
assert dask_model._client == client1
assert dask_model.client == client1
assert dask_model.client_ == client1
else:
assert dask_model._client is None
assert dask_model.client is None
assert dask_model.client_ == default_client()
assert dask_model.client_ == client2
assert isinstance(fitted_model_from_disk, model_factory)
assert fitted_model_from_disk._client is None
assert fitted_model_from_disk.client is None
assert fitted_model_from_disk.client_ == default_client()
assert fitted_model_from_disk.client_ == client2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment