Unverified commit e61bcbec, authored by Malte Londschien and committed by GitHub
Browse files

[python-package] Infer feature names from `pyarrow.Table` (#6781)

parent e0c34e7b
...@@ -2126,6 +2126,8 @@ class Dataset: ...@@ -2126,6 +2126,8 @@ class Dataset:
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
pandas_categorical=self.pandas_categorical, pandas_categorical=self.pandas_categorical,
) )
elif _is_pyarrow_table(data) and feature_name == "auto":
feature_name = data.column_names
# process for args # process for args
params = {} if params is None else params params = {} if params is None else params
...@@ -2185,7 +2187,6 @@ class Dataset: ...@@ -2185,7 +2187,6 @@ class Dataset:
self.__init_from_np2d(data, params_str, ref_dataset) self.__init_from_np2d(data, params_str, ref_dataset)
elif _is_pyarrow_table(data): elif _is_pyarrow_table(data):
self.__init_from_pyarrow_table(data, params_str, ref_dataset) self.__init_from_pyarrow_table(data, params_str, ref_dataset)
feature_name = data.column_names
elif isinstance(data, list) and len(data) > 0: elif isinstance(data, list) and len(data) > 0:
if _is_list_of_numpy_arrays(data): if _is_list_of_numpy_arrays(data):
self.__init_from_list_np2d(data, params_str, ref_dataset) self.__init_from_list_np2d(data, params_str, ref_dataset)
......
...@@ -432,3 +432,25 @@ def test_predict_ranking(): ...@@ -432,3 +432,25 @@ def test_predict_ranking():
num_boost_round=5, num_boost_round=5,
) )
assert_equal_predict_arrow_pandas(booster, data) assert_equal_predict_arrow_pandas(booster, data)
def test_arrow_feature_name_auto():
    """Train on an Arrow table without passing ``feature_name``.

    The booster is expected to report ``["a", "b"]`` as its feature names
    (presumably the column names of the dummy Arrow table -- confirm against
    ``generate_dummy_arrow_table``), and the categorical feature can be
    referenced by the name ``"a"``.
    """
    table = generate_dummy_arrow_table()
    labels = pa.array([0, 1, 0, 0, 1])
    train_set = lgb.Dataset(
        table,
        label=labels,
        params=dummy_dataset_params(),
        categorical_feature=["a"],
    )
    model = lgb.train({"num_leaves": 7}, train_set, num_boost_round=5)

    expected_names = ["a", "b"]
    assert model.feature_name() == expected_names
def test_arrow_feature_name_manual():
    """Train on an Arrow table while passing explicit feature names.

    The user-supplied names ``["c", "d"]`` are expected to be what the
    booster reports, and the categorical feature is referenced by the
    user-supplied name ``"c"``.
    """
    table = generate_dummy_arrow_table()
    # Keyword arguments are collected first; their evaluation order matches
    # the order in which they would appear in a direct call.
    dataset_kwargs = {
        "label": pa.array([0, 1, 0, 0, 1]),
        "params": dummy_dataset_params(),
        "feature_name": ["c", "d"],
        "categorical_feature": ["c"],
    }
    train_set = lgb.Dataset(table, **dataset_kwargs)
    model = lgb.train({"num_leaves": 7}, train_set, num_boost_round=5)

    # Compare against a fresh literal rather than the list handed to Dataset.
    assert model.feature_name() == ["c", "d"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment