Unverified Commit 11110c54 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] remove `Booster.set_attr()` and `Booster.attr()` (#5272)

parent f1328d5c
...@@ -2570,7 +2570,6 @@ class Booster: ...@@ -2570,7 +2570,6 @@ class Booster:
self.network = False self.network = False
self.__need_reload_eval_info = True self.__need_reload_eval_info = True
self._train_data_name = "training" self._train_data_name = "training"
self.__attr = {}
self.__set_objective_to_none = False self.__set_objective_to_none = False
self.best_iteration = -1 self.best_iteration = -1
self.best_score = {} self.best_score = {}
...@@ -3652,7 +3651,6 @@ class Booster: ...@@ -3652,7 +3651,6 @@ class Booster:
ctypes.c_int32(nrow), ctypes.c_int32(nrow),
ctypes.c_int32(ncol))) ctypes.c_int32(ncol)))
new_booster.network = self.network new_booster.network = self.network
new_booster.__attr = self.__attr.copy()
return new_booster return new_booster
def get_leaf_output(self, tree_id, leaf_id): def get_leaf_output(self, tree_id, leaf_id):
...@@ -3956,42 +3954,3 @@ class Booster: ...@@ -3956,42 +3954,3 @@ class Booster:
self.__higher_better_inner_eval = [ self.__higher_better_inner_eval = [
name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval
] ]
def attr(self, key: str) -> Optional[str]:
"""Get attribute string from the Booster.
Parameters
----------
key : str
The name of the attribute.
Returns
-------
value : str or None
The attribute value.
Returns None if attribute does not exist.
"""
return self.__attr.get(key, None)
def set_attr(self, **kwargs: Any) -> "Booster":
"""Set attributes to the Booster.
Parameters
----------
**kwargs
The attributes to set.
Setting a value to None deletes an attribute.
Returns
-------
self : Booster
Booster with set attributes.
"""
for key, value in kwargs.items():
if value is not None:
if not isinstance(value, str):
raise ValueError("Only string values are accepted")
self.__attr[key] = value
else:
self.__attr.pop(key, None)
return self
...@@ -29,7 +29,7 @@ class UnpicklableCallback: ...@@ -29,7 +29,7 @@ class UnpicklableCallback:
raise Exception("This class in not picklable") raise Exception("This class in not picklable")
def __call__(self, env): def __call__(self, env):
env.model.set_attr(attr_set_inside_callback=str(env.iteration * 10)) env.model.attr_set_inside_callback = env.iteration * 10
def custom_asymmetric_obj(y_true, y_pred): def custom_asymmetric_obj(y_true, y_pred):
...@@ -480,7 +480,7 @@ def test_non_serializable_objects_in_callbacks(tmp_path): ...@@ -480,7 +480,7 @@ def test_non_serializable_objects_in_callbacks(tmp_path):
X, y = make_synthetic_regression() X, y = make_synthetic_regression()
gbm = lgb.LGBMRegressor(n_estimators=5) gbm = lgb.LGBMRegressor(n_estimators=5)
gbm.fit(X, y, callbacks=[unpicklable_callback]) gbm.fit(X, y, callbacks=[unpicklable_callback])
assert gbm.booster_.attr('attr_set_inside_callback') == '40' assert gbm.booster_.attr_set_inside_callback == 40
def test_random_state_object(): def test_random_state_object():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment