Unverified Commit b7f6311f authored by Oliver Borchert's avatar Oliver Borchert Committed by GitHub
Browse files

[python-package] Allow to pass Arrow array as labels (#6163)

parent 16004228
......@@ -555,6 +555,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
int num_element,
int type);
/*!
* \brief Set vector to a content in info.
* \note
* - \a label convert input datatype into ``float32``.
* \param handle Handle of dataset
* \param field_name Field name, can be \a label
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
const char* field_name,
int64_t n_chunks,
const ArrowArray* chunks,
const ArrowSchema* schema);
/*!
* \brief Get info vector from dataset.
* \param handle Handle of dataset
......
......@@ -110,6 +110,7 @@ class Metadata {
const std::vector<data_size_t>& used_data_indices);
void SetLabel(const label_t* label, data_size_t len);
void SetLabel(const ArrowChunkedArray& array);
void SetWeights(const label_t* weights, data_size_t len);
......@@ -334,6 +335,9 @@ class Metadata {
void CalculateQueryBoundaries();
/*! \brief Insert labels at the given index */
void InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len);
/*! \brief Set labels from pointers to the first element and the end of an iterator. */
template <typename It>
void SetLabelsFromIterator(It first, It last);
/*! \brief Insert weights at the given index */
void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
/*! \brief Insert initial scores at the given index */
......@@ -655,6 +659,8 @@ class Dataset {
LIGHTGBM_EXPORT void FinishLoad();
bool SetFieldFromArrow(const char* field_name, const ArrowChunkedArray& ca);
LIGHTGBM_EXPORT bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);
LIGHTGBM_EXPORT bool SetDoubleField(const char* field_name, const double* field_data, data_size_t num_element);
......
......@@ -19,7 +19,7 @@ import numpy as np
import scipy.sparse
from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
dt_DataTable, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
from .libpath import find_lib_path
if TYPE_CHECKING:
......@@ -99,7 +99,9 @@ _LGBM_LabelType = Union[
List[int],
np.ndarray,
pd_Series,
pd_DataFrame
pd_DataFrame,
pa_Array,
pa_ChunkedArray,
]
_LGBM_PredictDataType = Union[
str,
......@@ -353,6 +355,11 @@ def _is_2d_collection(data: Any) -> bool:
)
def _is_pyarrow_array(data: Any) -> bool:
"""Check whether data is a PyArrow array."""
return isinstance(data, (pa_Array, pa_ChunkedArray))
def _is_pyarrow_table(data: Any) -> bool:
"""Check whether data is a PyArrow table."""
return isinstance(data, pa_Table)
......@@ -384,7 +391,11 @@ class _ArrowCArray:
def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
"""Export an Arrow type to its C representation."""
# Obtain objects to export
if isinstance(data, pa_Table):
if isinstance(data, pa_Array):
export_objects = [data]
elif isinstance(data, pa_ChunkedArray):
export_objects = data.chunks
elif isinstance(data, pa_Table):
export_objects = data.to_batches()
else:
raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")
......@@ -1620,7 +1631,7 @@ class Dataset:
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data.
reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference.
......@@ -2402,7 +2413,7 @@ class Dataset:
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
......@@ -2519,7 +2530,7 @@ class Dataset:
def set_field(
self,
field_name: str,
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame]]
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Array, pa_ChunkedArray]]
) -> "Dataset":
"""Set property into the Dataset.
......@@ -2527,7 +2538,7 @@ class Dataset:
----------
field_name : str
The field name of the information.
data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None
data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray or None
The data to be set.
Returns
......@@ -2546,6 +2557,20 @@ class Dataset:
ctypes.c_int(0),
ctypes.c_int(_FIELD_TYPE_MAPPER[field_name])))
return self
# If the data is a arrow data, we can just pass it to C
if _is_pyarrow_array(data):
c_array = _export_arrow_to_c(data)
_safe_call(_LIB.LGBM_DatasetSetFieldFromArrow(
self._handle,
_c_str(field_name),
ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.chunks_ptr),
ctypes.c_void_p(c_array.schema_ptr),
))
self.version += 1
return self
dtype: "np.typing.DTypeLike"
if field_name == 'init_score':
dtype = np.float64
......@@ -2749,7 +2774,7 @@ class Dataset:
Parameters
----------
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None
The label information to be set into Dataset.
Returns
......@@ -2774,6 +2799,8 @@ class Dataset:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
elif _is_pyarrow_array(label):
label_array = label
else:
label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label')
self.set_field('label', label_array)
......@@ -4353,7 +4380,7 @@ class Booster:
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source for refit.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
label : list, numpy 1-D array or pandas Series / one-column DataFrame
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray
Label for refit.
decay_rate : float, optional (default=0.9)
Decay rate of refit,
......
......@@ -187,6 +187,8 @@ except ImportError:
"""pyarrow"""
try:
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_floating as arrow_is_floating
......@@ -195,6 +197,18 @@ try:
except ImportError:
PYARROW_INSTALLED = False
class pa_Array: # type: ignore
"""Dummy class for pa.Array."""
def __init__(self, *args, **kwargs):
pass
class pa_ChunkedArray: # type: ignore
"""Dummy class for pa.ChunkedArray."""
def __init__(self, *args, **kwargs):
pass
class pa_Table: # type: ignore
"""Dummy class for pa.Table."""
......
......@@ -833,6 +833,7 @@ class Booster {
// explicitly declare symbols from LightGBM namespace
using LightGBM::AllgatherFunction;
using LightGBM::ArrowChunkedArray;
using LightGBM::ArrowTable;
using LightGBM::Booster;
using LightGBM::Common::CheckElementsIntervalClosed;
......@@ -1780,6 +1781,21 @@ int LGBM_DatasetSetField(DatasetHandle handle,
API_END();
}
int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
const char* field_name,
int64_t n_chunks,
const ArrowArray* chunks,
const ArrowSchema* schema) {
API_BEGIN();
auto dataset = reinterpret_cast<Dataset*>(handle);
ArrowChunkedArray ca(n_chunks, chunks, schema);
auto is_success = dataset->SetFieldFromArrow(field_name, ca);
if (!is_success) {
Log::Fatal("Input field is not supported");
}
API_END();
}
int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name,
int* out_len,
......
......@@ -897,6 +897,17 @@ void Dataset::CopySubrow(const Dataset* fullset,
#endif // USE_CUDA
}
bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray &ca) {
std::string name(field_name);
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
metadata_.SetLabel(ca);
} else {
return false;
}
return true;
}
bool Dataset::SetFloatField(const char* field_name, const float* field_data,
data_size_t num_element) {
std::string name(field_name);
......
......@@ -403,27 +403,39 @@ void Metadata::InsertInitScores(const double* init_scores, data_size_t start_ind
// CUDA is handled after all insertions are complete
}
void Metadata::SetLabel(const label_t* label, data_size_t len) {
template <typename It>
void Metadata::SetLabelsFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_);
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
if (num_data_ != last - first) {
Log::Fatal("Length of labels differs from the length of #data");
}
if (num_data_ != len) {
Log::Fatal("Length of label is not same with #data");
if (label_.empty()) {
label_.resize(num_data_);
}
if (label_.empty()) { label_.resize(num_data_); }
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024)
for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = Common::AvoidInf(label[i]);
label_[i] = Common::AvoidInf(first[i]);
}
#ifdef USE_CUDA
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetLabel(label_.data(), len);
cuda_metadata_->SetLabel(label_.data(), label_.size());
}
#endif // USE_CUDA
}
void Metadata::SetLabel(const label_t* label, data_size_t len) {
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
}
SetLabelsFromIterator(label, label + len);
}
void Metadata::SetLabel(const ArrowChunkedArray& array) {
SetLabelsFromIterator(array.begin<label_t>(), array.end<label_t>());
}
void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len) {
if (labels == nullptr) {
Log::Fatal("label cannot be nullptr");
......
......@@ -67,6 +67,10 @@ def dummy_dataset_params() -> Dict[str, Any]:
}
def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray):
assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs)
# ----------------------------------------------------------------------------------------------- #
# UNIT TESTS #
# ----------------------------------------------------------------------------------------------- #
......@@ -97,3 +101,45 @@ def test_dataset_construct_fuzzy(
arrow_dataset._dump_text(tmp_path / "arrow.txt")
pandas_dataset._dump_text(tmp_path / "pandas.txt")
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")
@pytest.mark.parametrize(
["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
)
@pytest.mark.parametrize(
"arrow_type",
[
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64(),
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64(),
pa.float32(),
pa.float64(),
],
)
def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
data = generate_dummy_arrow_table()
labels = array_type(label_data, type=arrow_type)
dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
dataset.construct()
expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
assert_arrays_equal(expected, dataset.get_label())
def test_dataset_construct_labels_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_array = generate_random_arrow_array(1000, 42)
arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array)
arrow_dataset.construct()
pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy())
pandas_dataset.construct()
assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment