Unverified Commit 237ac299 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] handle arbitrary length feature names in Python-package (#4293)

* handle arbitrary length feature names in Python-package

* added tests
parent 41a1a242
......@@ -1875,7 +1875,7 @@ class Dataset:
tmp_out_len = ctypes.c_int(0)
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range(num_feature)]
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
self.handle,
......@@ -1886,11 +1886,18 @@ class Dataset:
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
f"Allocated feature name buffer size ({reserved_string_buffer_size}) was"
f"inferior to the needed size ({required_string_buffer_size.value})."
)
actual_string_buffer_size = required_string_buffer_size.value
# if buffer length is not long enough, reallocate buffers
if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_DatasetGetFeatureNames(
self.handle,
ctypes.c_int(num_feature),
ctypes.byref(tmp_out_len),
ctypes.c_size_t(actual_string_buffer_size),
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]
def get_label(self):
......@@ -3249,7 +3256,7 @@ class Booster:
tmp_out_len = ctypes.c_int(0)
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for i in range(num_feature)]
string_buffers = [ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetFeatureNames(
self.handle,
......@@ -3260,9 +3267,18 @@ class Booster:
ptr_string_buffers))
if num_feature != tmp_out_len.value:
raise ValueError("Length of feature names doesn't equal with num_feature")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
f"Allocated feature name buffer size ({reserved_string_buffer_size}) was inferior to the needed size ({required_string_buffer_size.value}).")
actual_string_buffer_size = required_string_buffer_size.value
# if buffer length is not long enough, reallocate buffers
if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(num_feature)]
ptr_string_buffers = (ctypes.c_char_p * num_feature)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetFeatureNames(
self.handle,
ctypes.c_int(num_feature),
ctypes.byref(tmp_out_len),
ctypes.c_size_t(actual_string_buffer_size),
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]
def feature_importance(self, importance_type='split', iteration=None):
......@@ -3445,7 +3461,7 @@ class Booster:
reserved_string_buffer_size = 255
required_string_buffer_size = ctypes.c_size_t(0)
string_buffers = [
ctypes.create_string_buffer(reserved_string_buffer_size) for i in range(self.__num_inner_eval)
ctypes.create_string_buffer(reserved_string_buffer_size) for _ in range(self.__num_inner_eval)
]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames(
......@@ -3457,13 +3473,26 @@ class Booster:
ptr_string_buffers))
if self.__num_inner_eval != tmp_out_len.value:
raise ValueError("Length of eval names doesn't equal with num_evals")
if reserved_string_buffer_size < required_string_buffer_size.value:
raise BufferError(
f"Allocated eval name buffer size ({reserved_string_buffer_size}) was inferior to the needed size ({required_string_vuffer_size.value}).")
self.__name_inner_eval = \
[string_buffers[i].value.decode('utf-8') for i in range(self.__num_inner_eval)]
self.__higher_better_inner_eval = \
[name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval]
actual_string_buffer_size = required_string_buffer_size.value
# if buffer length is not long enough, reallocate buffers
if reserved_string_buffer_size < actual_string_buffer_size:
string_buffers = [
ctypes.create_string_buffer(actual_string_buffer_size) for _ in range(self.__num_inner_eval)
]
ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
_safe_call(_LIB.LGBM_BoosterGetEvalNames(
self.handle,
ctypes.c_int(self.__num_inner_eval),
ctypes.byref(tmp_out_len),
ctypes.c_size_t(actual_string_buffer_size),
ctypes.byref(required_string_buffer_size),
ptr_string_buffers))
self.__name_inner_eval = [
string_buffers[i].value.decode('utf-8') for i in range(self.__num_inner_eval)
]
self.__higher_better_inner_eval = [
name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval
]
def attr(self, key):
"""Get attribute string from the Booster.
......
......@@ -16,7 +16,9 @@ from .utils import load_breast_cancer
def test_basic(tmp_path):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True),
test_size=0.1, random_state=2)
train_data = lgb.Dataset(X_train, label=y_train)
feature_names = [f"Column_{i}" for i in range(X_train.shape[1])]
feature_names[1] = "a" * 1000 # set one name to a value longer than default buffer size
train_data = lgb.Dataset(X_train, label=y_train, feature_name=feature_names)
valid_data = train_data.create_valid(X_test, label=y_test)
params = {
......@@ -37,6 +39,8 @@ def test_basic(tmp_path):
if i % 10 == 0:
print(bst.eval_train(), bst.eval_valid())
assert train_data.get_feature_name() == feature_names
assert bst.current_iteration() == 20
assert bst.num_trees() == 20
assert bst.num_model_per_iteration() == 1
......@@ -55,6 +59,7 @@ def test_basic(tmp_path):
# check saved model persistence
bst = lgb.Booster(params, model_file=model_file)
assert bst.feature_name() == feature_names
pred_from_model_file = bst.predict(X_test)
# we need to check the consistency of model file here, so test for exact equal
np.testing.assert_array_equal(pred_from_matr, pred_from_model_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment