Unverified Commit b7f6311f authored by Oliver Borchert's avatar Oliver Borchert Committed by GitHub
Browse files

[python-package] Allow to pass Arrow array as labels (#6163)

parent 16004228
...@@ -555,6 +555,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -555,6 +555,23 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
int num_element, int num_element,
int type); int type);
/*!
 * \brief Set a vector of metadata content (e.g. labels) from Arrow data.
 * \note
 * - \a label converts the input datatype into ``float32``.
 * \param handle Handle of dataset
 * \param field_name Field name, can be \a label
 * \param n_chunks The number of Arrow arrays passed to this function
 * \param chunks Pointer to the list of Arrow arrays
 * \param schema Pointer to the schema of all Arrow arrays
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
                                                    const char* field_name,
                                                    int64_t n_chunks,
                                                    const ArrowArray* chunks,
                                                    const ArrowSchema* schema);
/*! /*!
* \brief Get info vector from dataset. * \brief Get info vector from dataset.
* \param handle Handle of dataset * \param handle Handle of dataset
......
...@@ -110,6 +110,7 @@ class Metadata { ...@@ -110,6 +110,7 @@ class Metadata {
const std::vector<data_size_t>& used_data_indices); const std::vector<data_size_t>& used_data_indices);
void SetLabel(const label_t* label, data_size_t len); void SetLabel(const label_t* label, data_size_t len);
void SetLabel(const ArrowChunkedArray& array);
void SetWeights(const label_t* weights, data_size_t len); void SetWeights(const label_t* weights, data_size_t len);
...@@ -334,6 +335,9 @@ class Metadata { ...@@ -334,6 +335,9 @@ class Metadata {
void CalculateQueryBoundaries(); void CalculateQueryBoundaries();
/*! \brief Insert labels at the given index */ /*! \brief Insert labels at the given index */
void InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len); void InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len);
/*! \brief Set labels from pointers to the first element and the end of an iterator. */
template <typename It>
void SetLabelsFromIterator(It first, It last);
/*! \brief Insert weights at the given index */ /*! \brief Insert weights at the given index */
void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len); void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
/*! \brief Insert initial scores at the given index */ /*! \brief Insert initial scores at the given index */
...@@ -655,6 +659,8 @@ class Dataset { ...@@ -655,6 +659,8 @@ class Dataset {
LIGHTGBM_EXPORT void FinishLoad(); LIGHTGBM_EXPORT void FinishLoad();
bool SetFieldFromArrow(const char* field_name, const ArrowChunkedArray& ca);
LIGHTGBM_EXPORT bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element); LIGHTGBM_EXPORT bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);
LIGHTGBM_EXPORT bool SetDoubleField(const char* field_name, const double* field_data, data_size_t num_element); LIGHTGBM_EXPORT bool SetDoubleField(const char* field_name, const double* field_data, data_size_t num_element);
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import scipy.sparse import scipy.sparse
from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat, from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
dt_DataTable, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series) dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
from .libpath import find_lib_path from .libpath import find_lib_path
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -99,7 +99,9 @@ _LGBM_LabelType = Union[ ...@@ -99,7 +99,9 @@ _LGBM_LabelType = Union[
List[int], List[int],
np.ndarray, np.ndarray,
pd_Series, pd_Series,
pd_DataFrame pd_DataFrame,
pa_Array,
pa_ChunkedArray,
] ]
_LGBM_PredictDataType = Union[ _LGBM_PredictDataType = Union[
str, str,
...@@ -353,6 +355,11 @@ def _is_2d_collection(data: Any) -> bool: ...@@ -353,6 +355,11 @@ def _is_2d_collection(data: Any) -> bool:
) )
def _is_pyarrow_array(data: Any) -> bool:
    """Tell whether ``data`` is a PyArrow Array or ChunkedArray."""
    return isinstance(data, pa_Array) or isinstance(data, pa_ChunkedArray)
def _is_pyarrow_table(data: Any) -> bool: def _is_pyarrow_table(data: Any) -> bool:
"""Check whether data is a PyArrow table.""" """Check whether data is a PyArrow table."""
return isinstance(data, pa_Table) return isinstance(data, pa_Table)
...@@ -384,7 +391,11 @@ class _ArrowCArray: ...@@ -384,7 +391,11 @@ class _ArrowCArray:
def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray: def _export_arrow_to_c(data: pa_Table) -> _ArrowCArray:
"""Export an Arrow type to its C representation.""" """Export an Arrow type to its C representation."""
# Obtain objects to export # Obtain objects to export
if isinstance(data, pa_Table): if isinstance(data, pa_Array):
export_objects = [data]
elif isinstance(data, pa_ChunkedArray):
export_objects = data.chunks
elif isinstance(data, pa_Table):
export_objects = data.to_batches() export_objects = data.to_batches()
else: else:
raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow") raise ValueError(f"data of type '{type(data)}' cannot be exported to Arrow")
...@@ -1620,7 +1631,7 @@ class Dataset: ...@@ -1620,7 +1631,7 @@ class Dataset:
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence, list of numpy array or pyarrow Table
Data source of Dataset. Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data. Label of the data.
reference : Dataset or None, optional (default=None) reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference. If this is Dataset for validation, training data should be used as reference.
...@@ -2402,7 +2413,7 @@ class Dataset: ...@@ -2402,7 +2413,7 @@ class Dataset:
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source of Dataset. Data source of Dataset.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data. Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None) weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Weight for each instance. Weights should be non-negative. Weight for each instance. Weights should be non-negative.
...@@ -2519,7 +2530,7 @@ class Dataset: ...@@ -2519,7 +2530,7 @@ class Dataset:
def set_field( def set_field(
self, self,
field_name: str, field_name: str,
data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame]] data: Optional[Union[List[List[float]], List[List[int]], List[float], List[int], np.ndarray, pd_Series, pd_DataFrame, pa_Array, pa_ChunkedArray]]
) -> "Dataset": ) -> "Dataset":
"""Set property into the Dataset. """Set property into the Dataset.
...@@ -2527,7 +2538,7 @@ class Dataset: ...@@ -2527,7 +2538,7 @@ class Dataset:
---------- ----------
field_name : str field_name : str
The field name of the information. The field name of the information.
data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None data : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), pyarrow Array, pyarrow ChunkedArray or None
The data to be set. The data to be set.
Returns Returns
...@@ -2546,6 +2557,20 @@ class Dataset: ...@@ -2546,6 +2557,20 @@ class Dataset:
ctypes.c_int(0), ctypes.c_int(0),
ctypes.c_int(_FIELD_TYPE_MAPPER[field_name]))) ctypes.c_int(_FIELD_TYPE_MAPPER[field_name])))
return self return self
# If the data is an Arrow array, we can just pass it to C
if _is_pyarrow_array(data):
c_array = _export_arrow_to_c(data)
_safe_call(_LIB.LGBM_DatasetSetFieldFromArrow(
self._handle,
_c_str(field_name),
ctypes.c_int64(c_array.n_chunks),
ctypes.c_void_p(c_array.chunks_ptr),
ctypes.c_void_p(c_array.schema_ptr),
))
self.version += 1
return self
dtype: "np.typing.DTypeLike" dtype: "np.typing.DTypeLike"
if field_name == 'init_score': if field_name == 'init_score':
dtype = np.float64 dtype = np.float64
...@@ -2749,7 +2774,7 @@ class Dataset: ...@@ -2749,7 +2774,7 @@ class Dataset:
Parameters Parameters
---------- ----------
label : list, numpy 1-D array, pandas Series / one-column DataFrame or None label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None
The label information to be set into Dataset. The label information to be set into Dataset.
Returns Returns
...@@ -2774,6 +2799,8 @@ class Dataset: ...@@ -2774,6 +2799,8 @@ class Dataset:
# data has nullable dtypes, but we can specify na_value argument and copy will be made # data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan) label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label) label_array = np.ravel(label)
elif _is_pyarrow_array(label):
label_array = label
else: else:
label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label') label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label')
self.set_field('label', label_array) self.set_field('label', label_array)
...@@ -4353,7 +4380,7 @@ class Booster: ...@@ -4353,7 +4380,7 @@ class Booster:
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source for refit. Data source for refit.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
label : list, numpy 1-D array or pandas Series / one-column DataFrame label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array or pyarrow ChunkedArray
Label for refit. Label for refit.
decay_rate : float, optional (default=0.9) decay_rate : float, optional (default=0.9)
Decay rate of refit, Decay rate of refit,
......
...@@ -187,6 +187,8 @@ except ImportError: ...@@ -187,6 +187,8 @@ except ImportError:
"""pyarrow""" """pyarrow"""
try: try:
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table from pyarrow import Table as pa_Table
from pyarrow.cffi import ffi as arrow_cffi from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_floating as arrow_is_floating from pyarrow.types import is_floating as arrow_is_floating
...@@ -195,6 +197,18 @@ try: ...@@ -195,6 +197,18 @@ try:
except ImportError: except ImportError:
PYARROW_INSTALLED = False PYARROW_INSTALLED = False
# Fallback stubs defined when the ``pyarrow`` import above fails: they let
# ``isinstance`` checks and type annotations elsewhere in the package resolve
# without pyarrow being installed.
class pa_Array:  # type: ignore
    """Dummy class for pa.Array."""

    def __init__(self, *args, **kwargs):
        # No-op: real pyarrow objects are never constructed through this stub.
        pass


class pa_ChunkedArray:  # type: ignore
    """Dummy class for pa.ChunkedArray."""

    def __init__(self, *args, **kwargs):
        # No-op: real pyarrow objects are never constructed through this stub.
        pass
class pa_Table: # type: ignore class pa_Table: # type: ignore
"""Dummy class for pa.Table.""" """Dummy class for pa.Table."""
......
...@@ -833,6 +833,7 @@ class Booster { ...@@ -833,6 +833,7 @@ class Booster {
// explicitly declare symbols from LightGBM namespace // explicitly declare symbols from LightGBM namespace
using LightGBM::AllgatherFunction; using LightGBM::AllgatherFunction;
using LightGBM::ArrowChunkedArray;
using LightGBM::ArrowTable; using LightGBM::ArrowTable;
using LightGBM::Booster; using LightGBM::Booster;
using LightGBM::Common::CheckElementsIntervalClosed; using LightGBM::Common::CheckElementsIntervalClosed;
...@@ -1780,6 +1781,21 @@ int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -1780,6 +1781,21 @@ int LGBM_DatasetSetField(DatasetHandle handle,
API_END(); API_END();
} }
// C API entry point: forwards Arrow-backed metadata (e.g. labels) to the Dataset.
int LGBM_DatasetSetFieldFromArrow(DatasetHandle handle,
                                  const char* field_name,
                                  int64_t n_chunks,
                                  const ArrowArray* chunks,
                                  const ArrowSchema* schema) {
  API_BEGIN();
  // Wrap the raw Arrow C structures in a chunked-array view before dispatching.
  ArrowChunkedArray ca(n_chunks, chunks, schema);
  auto* dataset = reinterpret_cast<Dataset*>(handle);
  if (!dataset->SetFieldFromArrow(field_name, ca)) {
    Log::Fatal("Input field is not supported");
  }
  API_END();
}
int LGBM_DatasetGetField(DatasetHandle handle, int LGBM_DatasetGetField(DatasetHandle handle,
const char* field_name, const char* field_name,
int* out_len, int* out_len,
......
...@@ -897,6 +897,17 @@ void Dataset::CopySubrow(const Dataset* fullset, ...@@ -897,6 +897,17 @@ void Dataset::CopySubrow(const Dataset* fullset,
#endif // USE_CUDA #endif // USE_CUDA
} }
// Dispatch an Arrow chunked array to the matching metadata setter.
// Returns false when the (trimmed) field name is not supported.
bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray &ca) {
  const std::string name = Common::Trim(std::string(field_name));
  if (name == "label" || name == "target") {
    metadata_.SetLabel(ca);
    return true;
  }
  return false;
}
bool Dataset::SetFloatField(const char* field_name, const float* field_data, bool Dataset::SetFloatField(const char* field_name, const float* field_data,
data_size_t num_element) { data_size_t num_element) {
std::string name(field_name); std::string name(field_name);
......
...@@ -403,27 +403,39 @@ void Metadata::InsertInitScores(const double* init_scores, data_size_t start_ind ...@@ -403,27 +403,39 @@ void Metadata::InsertInitScores(const double* init_scores, data_size_t start_ind
// CUDA is handled after all insertions are complete // CUDA is handled after all insertions are complete
} }
void Metadata::SetLabel(const label_t* label, data_size_t len) { template <typename It>
void Metadata::SetLabelsFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (label == nullptr) { if (num_data_ != last - first) {
Log::Fatal("label cannot be nullptr"); Log::Fatal("Length of labels differs from the length of #data");
} }
if (num_data_ != len) { if (label_.empty()) {
Log::Fatal("Length of label is not same with #data"); label_.resize(num_data_);
} }
if (label_.empty()) { label_.resize(num_data_); }
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_data_ >= 1024)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = Common::AvoidInf(label[i]); label_[i] = Common::AvoidInf(first[i]);
} }
#ifdef USE_CUDA #ifdef USE_CUDA
if (cuda_metadata_ != nullptr) { if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetLabel(label_.data(), len); cuda_metadata_->SetLabel(label_.data(), label_.size());
} }
#endif // USE_CUDA #endif // USE_CUDA
} }
void Metadata::SetLabel(const label_t* label, data_size_t len) {
if (label == nullptr) {
Log::Fatal("label cannot be nullptr");
}
SetLabelsFromIterator(label, label + len);
}
void Metadata::SetLabel(const ArrowChunkedArray& array) {
SetLabelsFromIterator(array.begin<label_t>(), array.end<label_t>());
}
void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len) { void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data_size_t len) {
if (labels == nullptr) { if (labels == nullptr) {
Log::Fatal("label cannot be nullptr"); Log::Fatal("label cannot be nullptr");
......
...@@ -67,6 +67,10 @@ def dummy_dataset_params() -> Dict[str, Any]: ...@@ -67,6 +67,10 @@ def dummy_dataset_params() -> Dict[str, Any]:
} }
def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray):
    """Assert that two numpy arrays have the same dtype and the same contents.

    Split into two asserts so a failure message pinpoints whether the dtype
    or the values differ.
    """
    assert lhs.dtype == rhs.dtype, f"dtype mismatch: {lhs.dtype} != {rhs.dtype}"
    assert np.array_equal(lhs, rhs), "array contents differ"
# ----------------------------------------------------------------------------------------------- # # ----------------------------------------------------------------------------------------------- #
# UNIT TESTS # # UNIT TESTS #
# ----------------------------------------------------------------------------------------------- # # ----------------------------------------------------------------------------------------------- #
...@@ -97,3 +101,45 @@ def test_dataset_construct_fuzzy( ...@@ -97,3 +101,45 @@ def test_dataset_construct_fuzzy(
arrow_dataset._dump_text(tmp_path / "arrow.txt") arrow_dataset._dump_text(tmp_path / "arrow.txt")
pandas_dataset._dump_text(tmp_path / "pandas.txt") pandas_dataset._dump_text(tmp_path / "pandas.txt")
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt") assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")
@pytest.mark.parametrize(
    ["array_type", "label_data"],
    [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
)
@pytest.mark.parametrize(
    "arrow_type",
    [
        pa.int8(),
        pa.int16(),
        pa.int32(),
        pa.int64(),
        pa.uint8(),
        pa.uint16(),
        pa.uint32(),
        pa.uint64(),
        pa.float32(),
        pa.float64(),
    ],
)
def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
    """Labels given as pyarrow data must round-trip through the Dataset as float32."""
    labels = array_type(label_data, type=arrow_type)
    dataset = lgb.Dataset(generate_dummy_arrow_table(), label=labels, params=dummy_dataset_params())
    dataset.construct()

    assert_arrays_equal(np.array([0, 1, 0, 0, 1], dtype=np.float32), dataset.get_label())
def test_dataset_construct_labels_fuzzy():
    """Arrow labels and their numpy equivalent must yield identical dataset labels."""
    table = generate_random_arrow_table(3, 1000, 42)
    labels = generate_random_arrow_array(1000, 42)

    ds_arrow = lgb.Dataset(table, label=labels)
    ds_arrow.construct()
    ds_pandas = lgb.Dataset(table.to_pandas(), label=labels.to_numpy())
    ds_pandas.construct()

    assert_arrays_equal(ds_arrow.get_label(), ds_pandas.get_label())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment