Unverified Commit af5b40e1 authored by Yaqub Alwan's avatar Yaqub Alwan Committed by GitHub
Browse files

[python] raise an informative error instead of segfaulting when custom...


[python] raise an informative error instead of segfaulting when custom objective produces incorrect output (#4815)

* fix for bad grads causing segfault

* adjust checking criteria to properly reflect reality of multi-class classifiers

* fix styling

* Line break before operator

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update python-package/lightgbm/basic.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* add a note to the C-API docs

* rearrange text slightly

* add some tests to python package

* Update include/LightGBM/c_api.h
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* PR comments

* match argument is a regex and our expression has brackets.

* rework tests

* isorting imports

* updating test to reflect that the python API does not take preds/labels as a fobj function
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent a55ff183
......@@ -572,6 +572,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
/*!
* \brief Update the model by specifying gradient and Hessian directly
* (this can be used to support customized loss functions).
* \note
 * The length of the arrays referenced by ``grad`` and ``hess`` must be equal to
 * ``num_class * num_train_data``; this is not verified by the library, so the caller must ensure this.
* \param handle Handle of booster
* \param grad The first order derivative (gradient) statistics
* \param hess The second order derivative (Hessian) statistics
......
......@@ -3031,7 +3031,15 @@ class Booster:
assert grad.flags.c_contiguous
assert hess.flags.c_contiguous
if len(grad) != len(hess):
raise ValueError(f"Lengths of gradient({len(grad)}) and hessian({len(hess)}) don't match")
raise ValueError(f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match")
num_train_data = self.train_set.num_data()
num_models = self.__num_class
if len(grad) != num_train_data * num_models:
raise ValueError(
f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) "
f"don't match training data length ({num_train_data}) * "
f"number of models per one iteration ({num_models})"
)
is_finished = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
self.handle,
......
# coding: utf-8
import filecmp
import numbers
import re
from pathlib import Path
import numpy as np
......@@ -579,3 +580,32 @@ def test_param_aliases():
assert all(len(i) >= 1 for i in aliases.values())
assert all(k in v for k, v in aliases.items())
assert lgb.basic._ConfigAliases.get('config', 'task') == {'config', 'config_file', 'task', 'task_type'}
def _bad_gradients(preds, _):
return np.random.randn(len(preds) + 1), np.random.rand(len(preds) + 1)
def _good_gradients(preds, _):
return np.random.randn(len(preds)), np.random.rand(len(preds))
def test_custom_objective_safety():
    """Booster.update() with a custom objective must raise an informative
    ValueError — rather than segfault in the C library — when the objective
    returns gradient/Hessian arrays of the wrong length, for both binary
    (1 model per iteration) and multiclass (num_class models) boosters.
    """
    n_samples = 100
    features = np.random.randn(n_samples, 5)
    labels_binary = np.arange(n_samples) % 2
    class_values = [0, 1, 2]
    n_classes = len(class_values)
    labels_multiclass = np.arange(n_samples) % n_classes
    dataset_binary = lgb.Dataset(features, labels_binary).construct()
    dataset_multiclass = lgb.Dataset(features, labels_multiclass).construct()
    # one well-behaved and one misbehaving booster per problem type
    booster_binary_bad = lgb.Booster({'objective': "none"}, dataset_binary)
    booster_binary_good = lgb.Booster({'objective': "none"}, dataset_binary)
    booster_multi_bad = lgb.Booster({'objective': "none", "num_class": n_classes}, dataset_multiclass)
    booster_multi_good = lgb.Booster({'objective': "none", "num_class": n_classes}, dataset_multiclass)
    # correctly sized gradients train without error
    booster_binary_good.update(fobj=_good_gradients)
    # mismatched lengths must be rejected before reaching the C API;
    # match= is a regex, so escape the literal parentheses in the message
    with pytest.raises(ValueError, match=re.escape("number of models per one iteration (1)")):
        booster_binary_bad.update(fobj=_bad_gradients)
    booster_multi_good.update(fobj=_good_gradients)
    with pytest.raises(ValueError, match=re.escape(f"number of models per one iteration ({n_classes})")):
        booster_multi_bad.update(fobj=_bad_gradients)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment