Unverified Commit 29796eee authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[python-package] add Booster.set_leaf_output method (#5712)

parent 6f0bc481
...@@ -4043,6 +4043,38 @@ class Booster: ...@@ -4043,6 +4043,38 @@ class Booster:
ctypes.byref(ret))) ctypes.byref(ret)))
return ret.value return ret.value
def set_leaf_output(
self,
tree_id: int,
leaf_id: int,
value: float,
) -> 'Booster':
"""Set the output of a leaf.
Parameters
----------
tree_id : int
The index of the tree.
leaf_id : int
The index of the leaf in the tree.
value : float
Value to set as the output of the leaf.
Returns
-------
self : Booster
Booster with the leaf output set.
"""
_safe_call(
_LIB.LGBM_BoosterSetLeafValue(
self.handle,
ctypes.c_int(tree_id),
ctypes.c_int(leaf_id),
ctypes.c_double(value)
)
)
return self
def _to_predictor( def _to_predictor(
self, self,
pred_parameter: Optional[Dict[str, Any]] = None pred_parameter: Optional[Dict[str, Any]] = None
......
...@@ -793,3 +793,15 @@ def test_feature_num_bin_with_max_bin_by_feature(): ...@@ -793,3 +793,15 @@ def test_feature_num_bin_with_max_bin_by_feature():
ds = lgb.Dataset(X, params={'max_bin_by_feature': max_bin_by_feature}).construct() ds = lgb.Dataset(X, params={'max_bin_by_feature': max_bin_by_feature}).construct()
actual_num_bins = [ds.feature_num_bin(i) for i in range(X.shape[1])] actual_num_bins = [ds.feature_num_bin(i) for i in range(X.shape[1])]
np.testing.assert_equal(actual_num_bins, max_bin_by_feature) np.testing.assert_equal(actual_num_bins, max_bin_by_feature)
def test_set_leaf_output():
X, y = load_breast_cancer(return_X_y=True)
ds = lgb.Dataset(X, y)
bst = lgb.Booster({'num_leaves': 2}, ds)
bst.update()
y_pred = bst.predict(X)
for leaf_id in range(2):
leaf_output = bst.get_leaf_output(tree_id=0, leaf_id=leaf_id)
bst.set_leaf_output(tree_id=0, leaf_id=leaf_id, value=leaf_output + 1)
np.testing.assert_allclose(bst.predict(X), y_pred + 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment