Unverified commit d84582b7, authored by Oliver Borchert and committed by GitHub
Browse files

Fix null handling for Arrow data (#6227)

parent f5b6bd60
...@@ -144,7 +144,7 @@ struct ArrayIndexAccessor { ...@@ -144,7 +144,7 @@ struct ArrayIndexAccessor {
// - The structure of validity bitmasks is taken from here: // - The structure of validity bitmasks is taken from here:
// https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps // https://arrow.apache.org/docs/format/Columnar.html#validity-bitmaps
// - If the bitmask is NULL, all indices are valid // - If the bitmask is NULL, all indices are valid
if (validity == nullptr || !(validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) { if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
// In case the index is valid, we take it from the data buffer // In case the index is valid, we take it from the data buffer
auto data = static_cast<const T*>(array->buffers[1]); auto data = static_cast<const T*>(array->buffers[1]);
return static_cast<double>(data[buffer_idx]); return static_cast<double>(data[buffer_idx]);
......
...@@ -41,10 +41,12 @@ class ArrowChunkedArrayTest : public testing::Test { ...@@ -41,10 +41,12 @@ class ArrowChunkedArrayTest : public testing::Test {
// 1) Create validity bitmap // 1) Create validity bitmap
char* validity = nullptr; char* validity = nullptr;
if (!null_indices.empty()) { if (!null_indices.empty()) {
validity = static_cast<char*>(calloc(values.size() + sizeof(char) - 1, sizeof(char))); auto num_bytes = (values.size() + 7) / 8;
validity = static_cast<char*>(calloc(num_bytes, sizeof(char)));
memset(validity, 0xff, num_bytes * sizeof(char));
for (size_t i = 0; i < values.size(); ++i) { for (size_t i = 0; i < values.size(); ++i) {
if (std::find(null_indices.begin(), null_indices.end(), i) != null_indices.end()) { if (std::find(null_indices.begin(), null_indices.end(), i) != null_indices.end()) {
validity[i / 8] |= (1 << (i % 8)); validity[i / 8] &= ~(1 << (i % 8));
} }
} }
} }
......
...@@ -46,6 +46,16 @@ def generate_simple_arrow_table() -> pa.Table: ...@@ -46,6 +46,16 @@ def generate_simple_arrow_table() -> pa.Table:
return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))]) return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))])
def generate_nullable_arrow_table() -> pa.Table:
    """Build a small float32 Arrow table whose columns contain nulls.

    Every column is a chunked array with a single chunk of five entries.
    The null (``None``) entries sit at a different place in each column:
    second entry, first entry, last entry, and finally every entry.

    Returns:
        A ``pa.Table`` with four columns named ``col_0`` through ``col_3``.
    """
    # One tuple entry per column; ``None`` marks a null entry.
    column_values = (
        [1, None, 3, 4, 5],
        [None, 2, 3, 4, 5],
        [1, 2, 3, 4, None],
        [None, None, None, None, None],
    )
    columns = [pa.chunked_array([values], type=pa.float32()) for values in column_values]
    names = [f"col_{i}" for i in range(len(columns))]
    return pa.Table.from_arrays(columns, names=names)
def generate_dummy_arrow_table() -> pa.Table: def generate_dummy_arrow_table() -> pa.Table:
col1 = pa.chunked_array([[1, 2, 3], [4, 5]], type=pa.uint8()) col1 = pa.chunked_array([[1, 2, 3], [4, 5]], type=pa.uint8())
col2 = pa.chunked_array([[0.5, 0.6], [0.1, 0.8, 1.5]], type=pa.float32()) col2 = pa.chunked_array([[0.5, 0.6], [0.1, 0.8, 1.5]], type=pa.float32())
...@@ -95,6 +105,7 @@ def dummy_dataset_params() -> Dict[str, Any]: ...@@ -95,6 +105,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
[ # Use lambda functions here to minimize memory consumption [ # Use lambda functions here to minimize memory consumption
(lambda: generate_simple_arrow_table(), dummy_dataset_params()), (lambda: generate_simple_arrow_table(), dummy_dataset_params()),
(lambda: generate_dummy_arrow_table(), dummy_dataset_params()), (lambda: generate_dummy_arrow_table(), dummy_dataset_params()),
(lambda: generate_nullable_arrow_table(), dummy_dataset_params()),
(lambda: generate_random_arrow_table(3, 1000, 42), {}), (lambda: generate_random_arrow_table(3, 1000, 42), {}),
(lambda: generate_random_arrow_table(100, 10000, 43), {}), (lambda: generate_random_arrow_table(100, 10000, 43), {}),
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment