"python-package/vscode:/vscode.git/clone" did not exist on "595c10ab26165691502dc27219b454b21c24bd1d"
Unverified Commit faba8177 authored by Oliver Borchert's avatar Oliver Borchert Committed by GitHub
Browse files

[python-package] Allow to pass Arrow table with boolean columns to dataset (#6353)

parent 0a3e1a55
......@@ -84,6 +84,10 @@ class ArrowChunkedArray {
const ArrowSchema* schema_;
/* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
std::vector<int64_t> chunk_offsets_;
/* Indicator whether this chunked array needs to call the arrays' release callbacks.
NOTE: This MUST only be set to `true` if this chunked array is not part of an
`ArrowTable` as children arrays may not be released by the consumer (see below). */
const bool releases_arrow_;
inline void construct_chunk_offsets() {
chunk_offsets_.reserve(chunks_.size() + 1);
......@@ -100,7 +104,8 @@ class ArrowChunkedArray {
* @param chunks A list with the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema)
: releases_arrow_(false) {
chunks_ = chunks;
schema_ = schema;
construct_chunk_offsets();
......@@ -113,9 +118,9 @@ class ArrowChunkedArray {
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowChunkedArray(int64_t n_chunks,
const struct ArrowArray* chunks,
const struct ArrowSchema* schema) {
inline ArrowChunkedArray(int64_t n_chunks, const struct ArrowArray* chunks,
const struct ArrowSchema* schema)
: releases_arrow_(true) {
chunks_.reserve(n_chunks);
for (auto k = 0; k < n_chunks; ++k) {
if (chunks[k].length == 0) continue;
......@@ -125,6 +130,21 @@ class ArrowChunkedArray {
construct_chunk_offsets();
}
// Invokes the Arrow release callbacks for the chunks and the schema, but only
// when this chunked array owns them (`releases_arrow_` is true, i.e. it was not
// created as part of an `ArrowTable`, whose children the producer releases).
~ArrowChunkedArray() {
if (!releases_arrow_) {
return;
}
// Per the Arrow C data interface, `release` may only be called when it is
// non-null; a null callback means the array was already released.
for (size_t i = 0; i < chunks_.size(); ++i) {
auto chunk = chunks_[i];
if (chunk->release) {
chunk->release(const_cast<ArrowArray*>(chunk));
}
}
// The schema has its own release callback, guarded the same way
if (schema_->release) {
schema_->release(const_cast<ArrowSchema*>(schema_));
}
}
/**
* @brief Get the length of the chunked array.
* This method returns the cumulative length of all chunks.
......@@ -219,7 +239,7 @@ class ArrowTable {
* @param chunks A C-style array containing the chunks.
* @param schema The schema for all chunks.
*/
inline ArrowTable(int64_t n_chunks, const ArrowArray *chunks, const ArrowSchema *schema)
inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema)
: n_chunks_(n_chunks), chunks_ptr_(chunks), schema_ptr_(schema) {
columns_.reserve(schema->n_children);
for (int64_t j = 0; j < schema->n_children; ++j) {
......@@ -236,7 +256,8 @@ class ArrowTable {
~ArrowTable() {
// As consumer of the Arrow array, the Arrow table must release all Arrow arrays it receives
// as well as the schema. As per the specification, children arrays are released by the
// producer. See: https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers
// producer. See:
// https://arrow.apache.org/docs/format/CDataInterface.html#release-callback-semantics-for-consumers
for (int64_t i = 0; i < n_chunks_; ++i) {
auto chunk = &chunks_ptr_[i];
if (chunk->release) {
......
......@@ -31,8 +31,7 @@ inline ArrowChunkedArray::Iterator<T> ArrowChunkedArray::end() const {
/* ---------------------------------- ITERATOR IMPLEMENTATION ---------------------------------- */
template <typename T>
ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
getter_fn get,
ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array, getter_fn get,
int64_t ptr_chunk)
: array_(array), get_(get), ptr_chunk_(ptr_chunk) {
this->ptr_offset_ = 0;
......@@ -41,7 +40,7 @@ ArrowChunkedArray::Iterator<T>::Iterator(const ArrowChunkedArray& array,
template <typename T>
T ArrowChunkedArray::Iterator<T>::operator*() const {
auto chunk = array_.chunks_[ptr_chunk_];
return static_cast<T>(get_(chunk, ptr_offset_));
return get_(chunk, ptr_offset_);
}
template <typename T>
......@@ -54,7 +53,7 @@ T ArrowChunkedArray::Iterator<T>::operator[](I idx) const {
auto chunk = array_.chunks_[chunk_idx];
auto ptr_offset = static_cast<int64_t>(idx) - array_.chunk_offsets_[chunk_idx];
return static_cast<T>(get_(chunk, ptr_offset));
return get_(chunk, ptr_offset);
}
template <typename T>
......@@ -147,11 +146,28 @@ struct ArrayIndexAccessor {
if (validity == nullptr || (validity[buffer_idx / 8] & (1 << (buffer_idx % 8)))) {
// In case the index is valid, we take it from the data buffer
auto data = static_cast<const T*>(array->buffers[1]);
return static_cast<double>(data[buffer_idx]);
return static_cast<V>(data[buffer_idx]);
}
// In case the index is not valid, we return a default value
return arrow_primitive_missing_value<T>();
return arrow_primitive_missing_value<V>();
}
};
template <typename V>
struct ArrayIndexAccessor<bool, V> {
  // Specialization for booleans: Arrow stores boolean values bit-packed (one
  // bit per element), so both the validity and the data bit must be located by
  // byte index and bit mask. See https://arrow.apache.org/docs/format/Columnar.html
  V operator()(const ArrowArray* array, size_t idx) {
    auto buffer_idx = idx + array->offset;
    auto byte_idx = buffer_idx / 8;
    auto bit_mask = 1 << (buffer_idx % 8);
    auto validity = static_cast<const char*>(array->buffers[0]);
    if (validity != nullptr && !(validity[byte_idx] & bit_mask)) {
      // Null entry: fall back to the type-specific missing value
      return arrow_primitive_missing_value<V>();
    }
    // Valid entry: extract the single data bit and map it to 0 or 1
    auto data = static_cast<const char*>(array->buffers[1]);
    return static_cast<V>((data[byte_idx] & bit_mask) ? 1 : 0);
  }
};
......@@ -180,6 +196,8 @@ std::function<T(const ArrowArray*, size_t)> get_index_accessor(const char* dtype
return ArrayIndexAccessor<float, T>();
case 'g':
return ArrayIndexAccessor<double, T>();
case 'b':
return ArrayIndexAccessor<bool, T>();
default:
throw std::invalid_argument("unsupported Arrow datatype");
}
......
......@@ -22,6 +22,7 @@ from .compat import (
PANDAS_INSTALLED,
PYARROW_INSTALLED,
arrow_cffi,
arrow_is_boolean,
arrow_is_floating,
arrow_is_integer,
concat,
......@@ -1688,7 +1689,7 @@ class _InnerPredictor:
raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.")
# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")
# Prepare prediction output array
......@@ -2435,7 +2436,7 @@ class Dataset:
raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.")
# Check that the input is valid: we only handle numbers (for now)
if not all(arrow_is_integer(t) or arrow_is_floating(t) for t in table.schema.types):
if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types):
raise ValueError("Arrow table may only have integer or floating point datatypes")
# Export Arrow table to C
......
......@@ -222,6 +222,7 @@ try:
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_boolean as arrow_is_boolean
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer
......@@ -265,6 +266,7 @@ except ImportError:
equal = None
pa_chunked_array = None
arrow_is_boolean = None
arrow_is_integer = None
arrow_is_floating = None
......
......@@ -5,87 +5,151 @@
* Author: Oliver Borchert
*/
#include <gtest/gtest.h>
#include <LightGBM/arrow.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <cmath>
#include <cstdlib>
using LightGBM::ArrowChunkedArray;
using LightGBM::ArrowTable;
/* --------------------------------------------------------------------------------------------- */
/* UTILS */
/* --------------------------------------------------------------------------------------------- */
// This code is copied and adapted from the official Arrow producer examples:
// https://arrow.apache.org/docs/format/CDataInterface.html#exporting-a-struct-float32-utf8-array
// Release callback for test-produced schemas: recursively releases and frees
// every child schema, then marks this schema itself as released.
// Adapted from the official Arrow producer examples:
// https://arrow.apache.org/docs/format/CDataInterface.html
static void release_schema(struct ArrowSchema* schema) {
  if (schema->children != nullptr) {
    for (int64_t child_idx = 0; child_idx < schema->n_children; ++child_idx) {
      auto child = schema->children[child_idx];
      if (child->release != nullptr) {
        child->release(child);
      }
      free(child);
    }
    free(schema->children);
  }
  // A null `release` marks the schema as released per the C data interface
  schema->release = nullptr;
}
// Release callback for test-produced arrays: recursively releases and frees
// every child array, frees all owned buffers, then marks this array as released.
// Adapted from the official Arrow producer examples:
// https://arrow.apache.org/docs/format/CDataInterface.html
static void release_array(struct ArrowArray* array) {
  if (array->children != nullptr) {
    for (int64_t child_idx = 0; child_idx < array->n_children; ++child_idx) {
      auto child = array->children[child_idx];
      if (child->release != nullptr) {
        child->release(child);
      }
      free(child);
    }
    free(array->children);
  }
  // Buffers may be null (e.g. an absent validity bitmap), so guard each free
  for (int64_t buffer_idx = 0; buffer_idx < array->n_buffers; ++buffer_idx) {
    if (array->buffers[buffer_idx] != nullptr) {
      free(const_cast<void*>(array->buffers[buffer_idx]));
    }
  }
  free(array->buffers);
  // A null `release` marks the array as released per the C data interface
  array->release = nullptr;
}
/* ------------------------------------------ PRODUCER ----------------------------------------- */
class ArrowChunkedArrayTest : public testing::Test {
protected:
void SetUp() override {}
ArrowArray created_nested_array(const std::vector<ArrowArray*>& arrays) {
/* -------------------------------------- ARRAY CREATION ------------------------------------- */
// Builds an Arrow validity bitmap of `size` bits where every bit is set except
// the ones listed in `null_indices`. Returns `nullptr` (meaning "no nulls")
// when `null_indices` is empty; otherwise the caller owns the `malloc`ed buffer.
// NOTE: takes `null_indices` by const reference to avoid copying the vector on
// every call (clang-tidy performance-unnecessary-value-param); callers are
// unaffected by this signature change.
char* build_validity_bitmap(int64_t size, const std::vector<int64_t>& null_indices = {}) {
  if (null_indices.empty()) {
    return nullptr;
  }
  // One bit per value, rounded up to whole bytes
  auto num_bytes = (size + 7) / 8;
  auto validity = static_cast<char*>(malloc(num_bytes * sizeof(char)));
  // Start with all entries valid, then clear the bits of the null entries
  memset(validity, 0xff, num_bytes * sizeof(char));
  for (auto idx : null_indices) {
    validity[idx / 8] &= ~(1 << (idx % 8));
  }
  return validity;
}
ArrowArray build_primitive_array(void* data, int64_t size, int64_t offset,
std::vector<int64_t> null_indices) {
const void** buffers = (const void**)malloc(sizeof(void*) * 2);
buffers[0] = build_validity_bitmap(size, null_indices);
buffers[1] = data;
ArrowArray arr;
arr.buffers = nullptr;
arr.children = (ArrowArray**)arrays.data(); // NOLINT
arr.length = size - offset;
arr.null_count = static_cast<int64_t>(null_indices.size());
arr.offset = offset;
arr.n_buffers = 2;
arr.n_children = 0;
arr.buffers = buffers;
arr.children = nullptr;
arr.dictionary = nullptr;
arr.length = arrays[0]->length;
arr.n_buffers = 0;
arr.n_children = arrays.size();
arr.null_count = 0;
arr.offset = 0;
arr.release = &release_array;
arr.private_data = nullptr;
arr.release = nullptr;
return arr;
}
template <typename T>
ArrowArray create_primitive_array(const std::vector<T>& values,
int64_t offset = 0,
ArrowArray create_primitive_array(const std::vector<T>& values, int64_t offset = 0,
std::vector<int64_t> null_indices = {}) {
// NOTE: Arrow arrays have 64-bit alignment but we can safely ignore this in tests
// 1) Create validity bitmap
char* validity = nullptr;
if (!null_indices.empty()) {
auto num_bytes = (values.size() + 7) / 8;
validity = static_cast<char*>(calloc(num_bytes, sizeof(char)));
memset(validity, 0xff, num_bytes * sizeof(char));
for (size_t i = 0; i < values.size(); ++i) {
if (std::find(null_indices.begin(), null_indices.end(), i) != null_indices.end()) {
validity[i / 8] &= ~(1 << (i % 8));
}
auto buffer = static_cast<T*>(malloc(sizeof(T) * values.size()));
for (size_t i = 0; i < values.size(); ++i) {
buffer[i] = values[i];
}
return build_primitive_array(buffer, values.size(), offset, null_indices);
}
// Builds a bit-packed Arrow boolean array from `values`: Arrow stores booleans
// as one bit per element, so the data buffer holds ceil(n / 8) bytes.
ArrowArray create_primitive_array(const std::vector<bool>& values, int64_t offset = 0,
                                  std::vector<int64_t> null_indices = {}) {
  auto byte_count = (values.size() + 7) / 8;
  // `calloc` zero-initializes the buffer, so only 'true' bits need to be set
  auto packed = static_cast<char*>(calloc(sizeof(char), byte_count));
  size_t bit_idx = 0;
  for (bool value : values) {
    if (value) {
      packed[bit_idx / 8] |= (1 << (bit_idx % 8));
    }
    ++bit_idx;
  }
  return build_primitive_array(packed, values.size(), offset, null_indices);
}
// 2) Create buffers
const void** buffers = (const void**)malloc(sizeof(void*) * 2);
buffers[0] = validity;
buffers[1] = values.data() + offset;
ArrowArray created_nested_array(const std::vector<ArrowArray*>& arrays) {
auto children = static_cast<ArrowArray**>(malloc(sizeof(ArrowArray*) * arrays.size()));
for (size_t i = 0; i < arrays.size(); ++i) {
auto child = static_cast<ArrowArray*>(malloc(sizeof(ArrowArray)));
*child = *arrays[i];
children[i] = child;
}
// Create arrow array
ArrowArray arr;
arr.buffers = buffers;
arr.children = nullptr;
arr.dictionary = nullptr;
arr.length = values.size() - offset;
arr.length = children[0]->length;
arr.null_count = 0;
arr.offset = 0;
arr.n_buffers = 0;
arr.n_children = static_cast<int64_t>(arrays.size());
arr.buffers = nullptr;
arr.children = children;
arr.dictionary = nullptr;
arr.release = &release_array;
arr.private_data = nullptr;
arr.release = [](ArrowArray* arr) {
if (arr->buffers[0] != nullptr)
free((void*)(arr->buffers[0])); // NOLINT
free((void*)(arr->buffers)); // NOLINT
};
return arr;
}
// Builds a struct-typed schema ("+s") whose children point directly at the
// caller-provided schema objects; no release callback is set, so the caller
// retains ownership of the children.
ArrowSchema create_nested_schema(const std::vector<ArrowSchema*>& arrays) {
ArrowSchema schema;
// "+s" is the Arrow format string for a struct type
schema.format = "+s";
schema.name = nullptr;
schema.metadata = nullptr;
schema.flags = 0;
schema.n_children = arrays.size();
// NOTE(review): children alias the caller's vector storage, so the vector
// must outlive the returned schema — confirm against all callers.
schema.children = (ArrowSchema**)arrays.data(); // NOLINT
schema.dictionary = nullptr;
schema.private_data = nullptr;
schema.release = nullptr;
return schema;
}
/* ------------------------------------- SCHEMA CREATION ------------------------------------- */
template <typename T>
ArrowSchema create_primitive_schema() {
......@@ -102,27 +166,71 @@ class ArrowChunkedArrayTest : public testing::Test {
schema.n_children = 0;
schema.children = nullptr;
schema.dictionary = nullptr;
schema.release = nullptr;
schema.private_data = nullptr;
return schema;
}
// Schema factory specialization for booleans: "b" is the Arrow format string
// for a bit-packed boolean array. No release callback is set, so the returned
// schema owns nothing.
// NOTE(review): an explicit specialization at class scope is not accepted by
// all compilers (historically rejected by GCC) — confirm all CI toolchains build this.
template <>
ArrowSchema create_primitive_schema<bool>() {
ArrowSchema schema;
schema.format = "b";
schema.name = nullptr;
schema.metadata = nullptr;
schema.flags = 0;
schema.n_children = 0;
schema.children = nullptr;
schema.dictionary = nullptr;
schema.release = nullptr;
schema.private_data = nullptr;
return schema;
}
// Builds a struct-typed schema ("+s") that owns shallow copies of the given
// child schemas; `release_schema` frees those copies when the schema is released.
ArrowSchema create_nested_schema(const std::vector<ArrowSchema*>& arrays) {
  auto child_count = arrays.size();
  auto children = static_cast<ArrowSchema**>(malloc(sizeof(ArrowSchema*) * child_count));
  for (size_t child_idx = 0; child_idx < child_count; ++child_idx) {
    // Shallow-copy each child so the schema can free it independently of the caller
    children[child_idx] = static_cast<ArrowSchema*>(malloc(sizeof(ArrowSchema)));
    *children[child_idx] = *arrays[child_idx];
  }
  ArrowSchema schema;
  schema.format = "+s";  // Arrow format string for a struct type
  schema.name = nullptr;
  schema.metadata = nullptr;
  schema.flags = 0;
  schema.n_children = static_cast<int64_t>(child_count);
  schema.children = children;
  schema.dictionary = nullptr;
  schema.release = &release_schema;
  schema.private_data = nullptr;
  return schema;
}
};
/* --------------------------------------------------------------------------------------------- */
/* TESTS */
/* --------------------------------------------------------------------------------------------- */
TEST_F(ArrowChunkedArrayTest, GetLength) {
auto schema = create_primitive_schema<float>();
std::vector<float> dat1 = {1, 2};
auto arr1 = create_primitive_array(dat1);
ArrowChunkedArray ca1(1, &arr1, nullptr);
ArrowChunkedArray ca1(1, &arr1, &schema);
ASSERT_EQ(ca1.get_length(), 2);
std::vector<float> dat2 = {3, 4, 5, 6};
auto arr2 = create_primitive_array<float>(dat2);
ArrowArray arrs[2] = {arr1, arr2};
ArrowChunkedArray ca2(2, arrs, nullptr);
auto arr2 = create_primitive_array(dat1);
auto arr3 = create_primitive_array(dat2);
ArrowArray arrs[2] = {arr2, arr3};
ArrowChunkedArray ca2(2, arrs, &schema);
ASSERT_EQ(ca2.get_length(), 6);
arr1.release(&arr1);
arr2.release(&arr2);
std::vector<bool> dat3 = {true, false, true, true};
auto arr4 = create_primitive_array(dat3, 1);
ArrowChunkedArray ca3(1, &arr4, &schema);
ASSERT_EQ(ca3.get_length(), 3);
}
TEST_F(ArrowChunkedArrayTest, GetColumns) {
......@@ -149,18 +257,15 @@ TEST_F(ArrowChunkedArrayTest, GetColumns) {
auto ca2 = table.get_column(1);
ASSERT_EQ(ca2.get_length(), 3);
ASSERT_EQ(*ca2.begin<int32_t>(), 4);
arr1.release(&arr1);
arr2.release(&arr2);
}
TEST_F(ArrowChunkedArrayTest, IteratorArithmetic) {
std::vector<float> dat1 = {1, 2};
auto arr1 = create_primitive_array<float>(dat1);
auto arr1 = create_primitive_array(dat1);
std::vector<float> dat2 = {3, 4, 5, 6};
auto arr2 = create_primitive_array<float>(dat2);
auto arr2 = create_primitive_array(dat2);
std::vector<float> dat3 = {7};
auto arr3 = create_primitive_array<float>(dat3);
auto arr3 = create_primitive_array(dat3);
auto schema = create_primitive_schema<float>();
ArrowArray arrs[3] = {arr1, arr2, arr3};
......@@ -190,15 +295,39 @@ TEST_F(ArrowChunkedArrayTest, IteratorArithmetic) {
auto end = ca.end<int32_t>();
ASSERT_EQ(end - it, 2);
ASSERT_EQ(end - ca.begin<int32_t>(), 7);
}
TEST_F(ArrowChunkedArrayTest, BooleanIterator) {
std::vector<bool> dat1 = {false, true, false};
auto arr1 = create_primitive_array(dat1, 0, {2});
std::vector<bool> dat2 = {false, false, false, false, true, true, true, true, false, true};
auto arr2 = create_primitive_array(dat2, 1);
auto schema = create_primitive_schema<bool>();
ArrowArray arrs[2] = {arr1, arr2};
ArrowChunkedArray ca(2, arrs, &schema);
// Check for values in first chunk
auto it = ca.begin<float>();
ASSERT_EQ(*it, 0);
ASSERT_EQ(*(++it), 1);
ASSERT_TRUE(std::isnan(*(++it)));
// Check for some values in second chunk
ASSERT_EQ(*(++it), 0);
it += 3;
ASSERT_EQ(*it, 1);
it += 4;
ASSERT_EQ(*it, 0);
ASSERT_EQ(*(++it), 1);
arr1.release(&arr1);
arr2.release(&arr2);
arr2.release(&arr3);
// Check end
ASSERT_EQ(++it, ca.end<float>());
}
TEST_F(ArrowChunkedArrayTest, OffsetAndValidity) {
std::vector<float> dat = {0, 1, 2, 3, 4, 5, 6};
auto arr = create_primitive_array(dat, 2, {0, 1});
auto arr = create_primitive_array(dat, 2, {2, 3});
auto schema = create_primitive_schema<float>();
ArrowChunkedArray ca(1, &arr, &schema);
......
# coding: utf-8
import filecmp
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
......@@ -43,16 +44,17 @@ def generate_simple_arrow_table(empty_chunks: bool = False) -> pa.Table:
pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.int64()),
pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.float32()),
pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.float64()),
pa.chunked_array(c + [[True, True, False]] + c + [[False, True]] + c, type=pa.bool_()),
]
return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))])
def generate_nullable_arrow_table() -> pa.Table:
def generate_nullable_arrow_table(dtype: Any) -> pa.Table:
columns = [
pa.chunked_array([[1, None, 3, 4, 5]], type=pa.float32()),
pa.chunked_array([[None, 2, 3, 4, 5]], type=pa.float32()),
pa.chunked_array([[1, 2, 3, 4, None]], type=pa.float32()),
pa.chunked_array([[None, None, None, None, None]], type=pa.float32()),
pa.chunked_array([[1, None, 3, 4, 5]], type=dtype),
pa.chunked_array([[None, 2, 3, 4, 5]], type=dtype),
pa.chunked_array([[1, 2, 3, 4, None]], type=dtype),
pa.chunked_array([[None, None, None, None, None]], type=dtype),
]
return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))])
......@@ -120,13 +122,20 @@ def dummy_dataset_params() -> Dict[str, Any]:
# ------------------------------------------- DATASET ------------------------------------------- #
def assert_datasets_equal(tmp_path: Path, lhs: lgb.Dataset, rhs: lgb.Dataset):
    """Dump both datasets to text files and require byte-identical output."""
    lhs_file = tmp_path / "arrow.txt"
    rhs_file = tmp_path / "pandas.txt"
    lhs._dump_text(lhs_file)
    rhs._dump_text(rhs_file)
    assert filecmp.cmp(lhs_file, rhs_file)
@pytest.mark.parametrize(
("arrow_table_fn", "dataset_params"),
[ # Use lambda functions here to minimize memory consumption
(lambda: generate_simple_arrow_table(), dummy_dataset_params()),
(lambda: generate_simple_arrow_table(empty_chunks=True), dummy_dataset_params()),
(lambda: generate_dummy_arrow_table(), dummy_dataset_params()),
(lambda: generate_nullable_arrow_table(), dummy_dataset_params()),
(lambda: generate_nullable_arrow_table(pa.float32()), dummy_dataset_params()),
(lambda: generate_nullable_arrow_table(pa.int32()), dummy_dataset_params()),
(lambda: generate_random_arrow_table(3, 1000, 42), {}),
(lambda: generate_random_arrow_table(100, 10000, 43), {}),
],
......@@ -140,9 +149,22 @@ def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), params=dataset_params)
pandas_dataset.construct()
arrow_dataset._dump_text(tmp_path / "arrow.txt")
pandas_dataset._dump_text(tmp_path / "pandas.txt")
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")
assert_datasets_equal(tmp_path, arrow_dataset, pandas_dataset)
def test_dataset_construct_fuzzy_boolean(tmp_path):
    """A boolean Arrow table must build the same dataset as its float32 cast."""
    bool_table = generate_random_arrow_table(10, 10000, 42, generate_nulls=False, values=np.array([True, False]))
    schema = pa.schema([pa.field(f"col_{i}", pa.float32()) for i in range(len(bool_table.columns))])
    float_table = bool_table.cast(schema)

    bool_dataset = lgb.Dataset(bool_table)
    bool_dataset.construct()
    float_dataset = lgb.Dataset(float_table.to_pandas())
    float_dataset.construct()

    assert_datasets_equal(tmp_path, bool_dataset, float_dataset)
# -------------------------------------------- FIELDS ------------------------------------------- #
......@@ -195,6 +217,25 @@ def test_dataset_construct_labels(array_type, label_data, arrow_type):
np_assert_array_equal(expected, dataset.get_label(), strict=True)
@pytest.mark.parametrize(
    ["array_type", "label_data"],
    [
        (pa.array, [False, True, False, False, True]),
        (pa.chunked_array, [[False], [True, False, False, True]]),
        (pa.chunked_array, [[], [False], [True, False, False, True]]),
        (pa.chunked_array, [[False], [], [True, False], [], [], [False, True], []]),
    ],
)
def test_dataset_construct_labels_boolean(array_type, label_data):
    """Boolean Arrow labels must be stored on the dataset as float32 0/1 labels."""
    table = generate_dummy_arrow_table()
    label_array = array_type(label_data, type=pa.bool_())
    ds = lgb.Dataset(table, label=label_array, params=dummy_dataset_params())
    ds.construct()
    expected_labels = np.array([0, 1, 0, 0, 1], dtype=np.float32)
    np_assert_array_equal(expected_labels, ds.get_label(), strict=True)
# ------------------------------------------- WEIGHTS ------------------------------------------- #
......@@ -317,7 +358,10 @@ def assert_equal_predict_arrow_pandas(booster: lgb.Booster, data: pa.Table):
def test_predict_regression():
data = generate_random_arrow_table(10, 10000, 42)
data_float = generate_random_arrow_table(10, 10000, 42)
data_bool = generate_random_arrow_table(1, 10000, 42, generate_nulls=False, values=np.array([True, False]))
data = pa.Table.from_arrays(data_float.columns + data_bool.columns, names=data_float.schema.names + ["col_bool"])
dataset = lgb.Dataset(
data,
label=generate_random_arrow_array(10000, 43, generate_nulls=False),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment