Unverified commit f981fba7, authored by Yusuke Horibe and committed by GitHub
Browse files

[python-package] add PyArrow Table to get_data (#6911)


Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent eff6adf2
......@@ -3267,6 +3267,8 @@ class Dataset:
self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif isinstance(self.data, pa_Table):
self.data = self.data.take(self.used_indices)
elif _is_list_of_sequences(self.data) and len(self.data) > 0:
self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
else:
......
......@@ -118,7 +118,7 @@ def generate_random_arrow_array(
chunks = [chunk for chunk in chunks if len(chunk) > 0]
# Turn chunks into array
return pa.chunked_array([data], type=pa.float32())
return pa.chunked_array(chunks, type=pa.float32())
def dummy_dataset_params() -> Dict[str, Any]:
......@@ -456,6 +456,70 @@ def test_arrow_feature_name_manual():
assert booster.feature_name() == ["c", "d"]
def pyarrow_array_equal(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
    """Tell whether two chunked arrays carry the same values, nulls included.

    This is the pyarrow counterpart of ``np.array_equal()``. pyarrow arrays
    holding identical values do not compare equal when any of those values
    are null, so both inputs are converted to NumPy and compared with
    ``equal_nan=True``, which makes such arrays count as equal.

    Returns ``False`` right away when the two lengths differ.
    """
    same_length = len(arr1) == len(arr2)
    if not same_length:
        return False
    return np.array_equal(arr1.to_numpy(), arr2.to_numpy(), equal_nan=True)
def test_get_data_arrow_table():
    """``Dataset.get_data()`` returns an Arrow table matching the one it was built from.

    The returned table is compared with the source table on schema, shape,
    per-column type, per-column chunk count, whole-column values and finally
    chunk-by-chunk values.
    """
    source_table = generate_simple_arrow_table()
    # NOTE(review): free_raw_data=False is presumably what keeps the raw table
    # retrievable after construct() -- confirm against the Dataset docs.
    ds = lgb.Dataset(source_table, free_raw_data=False)
    ds.construct()
    round_tripped = ds.get_data()

    assert isinstance(round_tripped, pa.Table)
    assert round_tripped.schema == source_table.schema
    assert round_tripped.shape == source_table.shape

    for name in source_table.column_names:
        expected = source_table[name]
        actual = round_tripped[name]
        assert expected.type == actual.type
        assert expected.num_chunks == actual.num_chunks
        assert pyarrow_array_equal(expected, actual)
        # Chunk counts were asserted equal above, so pairing the chunks with
        # zip() visits every chunk of both columns.
        for expected_chunk, actual_chunk in zip(expected.chunks, actual.chunks):
            wrapped_expected = pa.chunked_array([expected_chunk])
            wrapped_actual = pa.chunked_array([actual_chunk])
            assert pyarrow_array_equal(wrapped_expected, wrapped_actual)
def test_get_data_arrow_table_subset(rng):
    """``get_data()`` on a subset Dataset returns exactly the selected rows as an Arrow table.

    ``rng`` supplies ``choice()``, which draws 100 distinct row indices out of
    a 1000-row, 3-column random table. The subset's data is then compared
    with ``Table.take()`` applied to the same (sorted) indices.
    """
    n_rows_kept = 100
    full_table = generate_random_arrow_table(num_columns=3, num_datapoints=1000, seed=42)
    full_dataset = lgb.Dataset(full_table, free_raw_data=False)
    full_dataset.construct()

    # Draw without replacement, then sort the chosen row indices.
    row_ids = sorted(rng.choice(a=full_table.shape[0], size=n_rows_kept, replace=False))

    subset_dataset = full_dataset.subset(row_ids).construct()
    reference = full_table.take(row_ids)
    subset_table = subset_dataset.get_data()

    assert isinstance(subset_table, pa.Table)
    assert subset_table.schema == reference.schema
    assert subset_table.shape == reference.shape
    assert len(subset_table) == len(row_ids)
    assert subset_table.shape == (n_rows_kept, 3)

    for name in reference.column_names:
        reference_column = reference[name]
        subset_column = subset_table[name]
        assert reference_column.type == subset_column.type
        assert pyarrow_array_equal(reference_column, subset_column)
def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
with pytest.raises(
lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment