Commit 516bde95 (unverified), authored by Oliver Borchert, committed by GitHub
Browse files

[python-package] Allow to pass Arrow array as groups (#6166)

parent bc694222
...@@ -558,9 +558,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, ...@@ -558,9 +558,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
/*! /*!
* \brief Set vector to a content in info. * \brief Set vector to a content in info.
* \note * \note
* - \a group converts input datatype into ``int32``;
* - \a label and \a weight convert input datatype into ``float32``. * - \a label and \a weight convert input datatype into ``float32``.
* \param handle Handle of dataset * \param handle Handle of dataset
* \param field_name Field name, can be \a label, \a weight * \param field_name Field name, can be \a label, \a weight, \a group
* \param n_chunks The number of Arrow arrays passed to this function * \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays * \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays * \param schema Pointer to the schema of all Arrow arrays
......
...@@ -116,6 +116,7 @@ class Metadata { ...@@ -116,6 +116,7 @@ class Metadata {
void SetWeights(const ArrowChunkedArray& array); void SetWeights(const ArrowChunkedArray& array);
void SetQuery(const data_size_t* query, data_size_t len); void SetQuery(const data_size_t* query, data_size_t len);
void SetQuery(const ArrowChunkedArray& array);
void SetPosition(const data_size_t* position, data_size_t len); void SetPosition(const data_size_t* position, data_size_t len);
...@@ -348,6 +349,9 @@ class Metadata { ...@@ -348,6 +349,9 @@ class Metadata {
void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size); void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
/*! \brief Insert queries at the given index */ /*! \brief Insert queries at the given index */
void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len); void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
/*!
 * \brief Set queries from a range of per-query sizes given as an iterator pair.
 * \param first Iterator to the first query size; it is indexed (``first[i]``) and subtracted (``last - first``)
 * \param last Iterator one past the last query size
 * \note An empty range (``last - first == 0``) clears the query boundaries.
 */
template <typename It>
void SetQueriesFromIterator(It first, It last);
/*! \brief Filename of current data */ /*! \brief Filename of current data */
std::string data_filename_; std::string data_filename_;
/*! \brief Number of data */ /*! \brief Number of data */
......
...@@ -70,7 +70,9 @@ _LGBM_GroupType = Union[ ...@@ -70,7 +70,9 @@ _LGBM_GroupType = Union[
List[float], List[float],
List[int], List[int],
np.ndarray, np.ndarray,
pd_Series pd_Series,
pa_Array,
pa_ChunkedArray,
] ]
_LGBM_PositionType = Union[ _LGBM_PositionType = Union[
np.ndarray, np.ndarray,
...@@ -1652,7 +1654,7 @@ class Dataset: ...@@ -1652,7 +1654,7 @@ class Dataset:
If this is Dataset for validation, training data should be used as reference. If this is Dataset for validation, training data should be used as reference.
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative. Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None) group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Group/query data. Group/query data.
Only used in the learning-to-rank task. Only used in the learning-to-rank task.
sum(group) = n_samples. sum(group) = n_samples.
...@@ -2432,7 +2434,7 @@ class Dataset: ...@@ -2432,7 +2434,7 @@ class Dataset:
Label of the data. Label of the data.
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative. Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None) group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Group/query data. Group/query data.
Only used in the learning-to-rank task. Only used in the learning-to-rank task.
sum(group) = n_samples. sum(group) = n_samples.
...@@ -2889,7 +2891,7 @@ class Dataset: ...@@ -2889,7 +2891,7 @@ class Dataset:
Parameters Parameters
---------- ----------
group : list, numpy 1-D array, pandas Series or None group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
Group/query data. Group/query data.
Only used in the learning-to-rank task. Only used in the learning-to-rank task.
sum(group) = n_samples. sum(group) = n_samples.
...@@ -2903,6 +2905,7 @@ class Dataset: ...@@ -2903,6 +2905,7 @@ class Dataset:
""" """
self.group = group self.group = group
if self._handle is not None and group is not None: if self._handle is not None and group is not None:
if not _is_pyarrow_array(group):
group = _list_to_1d_numpy(group, dtype=np.int32, name='group') group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
self.set_field('group', group) self.set_field('group', group)
# original values can be modified at cpp side # original values can be modified at cpp side
...@@ -4431,7 +4434,7 @@ class Booster: ...@@ -4431,7 +4434,7 @@ class Booster:
.. versionadded:: 4.0.0 .. versionadded:: 4.0.0
group : list, numpy 1-D array, pandas Series or None, optional (default=None) group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Group/query size for ``data``. Group/query size for ``data``.
Only used in the learning-to-rank task. Only used in the learning-to-rank task.
sum(group) = n_samples. sum(group) = n_samples.
......
...@@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray ...@@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
metadata_.SetLabel(ca); metadata_.SetLabel(ca);
} else if (name == std::string("weight") || name == std::string("weights")) { } else if (name == std::string("weight") || name == std::string("weights")) {
metadata_.SetWeights(ca); metadata_.SetWeights(ca);
} else if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQuery(ca);
} else { } else {
return false; return false;
} }
......
...@@ -507,30 +507,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da ...@@ -507,30 +507,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da
// CUDA is handled after all insertions are complete // CUDA is handled after all insertions are complete
} }
void Metadata::SetQuery(const data_size_t* query, data_size_t len) { template <typename It>
void Metadata::SetQueriesFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr // Clear query boundaries on empty input
if (query == nullptr || len == 0) { if (last - first == 0) {
query_boundaries_.clear(); query_boundaries_.clear();
num_queries_ = 0; num_queries_ = 0;
return; return;
} }
data_size_t sum = 0; data_size_t sum = 0;
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum) #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum)
for (data_size_t i = 0; i < len; ++i) { for (data_size_t i = 0; i < last - first; ++i) {
sum += query[i]; sum += first[i];
} }
if (num_data_ != sum) { if (num_data_ != sum) {
Log::Fatal("Sum of query counts is not same with #data"); Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", num_data_, sum);
} }
num_queries_ = len; num_queries_ = last - first;
query_boundaries_.resize(num_queries_ + 1); query_boundaries_.resize(num_queries_ + 1);
query_boundaries_[0] = 0; query_boundaries_[0] = 0;
for (data_size_t i = 0; i < num_queries_; ++i) { for (data_size_t i = 0; i < num_queries_; ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + query[i]; query_boundaries_[i + 1] = query_boundaries_[i] + first[i];
} }
CalculateQueryWeights(); CalculateQueryWeights();
query_load_from_file_ = false; query_load_from_file_ = false;
#ifdef USE_CUDA #ifdef USE_CUDA
if (cuda_metadata_ != nullptr) { if (cuda_metadata_ != nullptr) {
if (query_weights_.size() > 0) { if (query_weights_.size() > 0) {
...@@ -543,6 +547,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) { ...@@ -543,6 +547,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
#endif // USE_CUDA #endif // USE_CUDA
} }
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
SetQueriesFromIterator(query, query + len);
}
/*!
 * \brief Set query (group) sizes from an Arrow chunked array.
 * \param array Chunked array whose values are read as ``data_size_t`` query sizes
 */
void Metadata::SetQuery(const ArrowChunkedArray& array) {
  const auto sizes_begin = array.begin<data_size_t>();
  const auto sizes_end = array.end<data_size_t>();
  SetQueriesFromIterator(sizes_begin, sizes_end);
}
void Metadata::SetPosition(const data_size_t* positions, data_size_t len) { void Metadata::SetPosition(const data_size_t* positions, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr // save to nullptr
......
# coding: utf-8 # coding: utf-8
import filecmp import filecmp
from pathlib import Path from typing import Any, Dict
from typing import Any, Callable, Dict
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
...@@ -15,6 +14,21 @@ from .utils import np_assert_array_equal ...@@ -15,6 +14,21 @@ from .utils import np_assert_array_equal
# UTILITIES # # UTILITIES #
# ----------------------------------------------------------------------------------------------- # # ----------------------------------------------------------------------------------------------- #
# Arrow integer types used to parametrize tests: signed first, then unsigned, 8 to 64 bits.
_INTEGER_TYPES = [
    make_type()
    for make_type in (
        pa.int8,
        pa.int16,
        pa.int32,
        pa.int64,
        pa.uint8,
        pa.uint16,
        pa.uint32,
        pa.uint64,
    )
]
# Arrow floating-point types used to parametrize tests.
_FLOAT_TYPES = [make_type() for make_type in (pa.float32, pa.float64)]
def generate_simple_arrow_table() -> pa.Table: def generate_simple_arrow_table() -> pa.Table:
columns = [ columns = [
...@@ -85,9 +99,7 @@ def dummy_dataset_params() -> Dict[str, Any]: ...@@ -85,9 +99,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
(lambda: generate_random_arrow_table(100, 10000, 43), {}), (lambda: generate_random_arrow_table(100, 10000, 43), {}),
], ],
) )
def test_dataset_construct_fuzzy( def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
):
arrow_table = arrow_table_fn() arrow_table = arrow_table_fn()
arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params) arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
...@@ -108,17 +120,23 @@ def test_dataset_construct_fields_fuzzy(): ...@@ -108,17 +120,23 @@ def test_dataset_construct_fields_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42) arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_labels = generate_random_arrow_array(1000, 42) arrow_labels = generate_random_arrow_array(1000, 42)
arrow_weights = generate_random_arrow_array(1000, 42) arrow_weights = generate_random_arrow_array(1000, 42)
arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())
arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights) arrow_dataset = lgb.Dataset(
arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
)
arrow_dataset.construct() arrow_dataset.construct()
pandas_dataset = lgb.Dataset( pandas_dataset = lgb.Dataset(
arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy() arrow_table.to_pandas(),
label=arrow_labels.to_numpy(),
weight=arrow_weights.to_numpy(),
group=arrow_groups.to_numpy(),
) )
pandas_dataset.construct() pandas_dataset.construct()
# Check for equality # Check for equality
for field in ("label", "weight"): for field in ("label", "weight", "group"):
np_assert_array_equal( np_assert_array_equal(
arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
) )
...@@ -133,22 +151,8 @@ def test_dataset_construct_fields_fuzzy(): ...@@ -133,22 +151,8 @@ def test_dataset_construct_fields_fuzzy():
["array_type", "label_data"], ["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])], [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
"arrow_type", def test_dataset_construct_labels(array_type, label_data, arrow_type):
[
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64(),
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64(),
pa.float32(),
pa.float64(),
],
)
def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
data = generate_dummy_arrow_table() data = generate_dummy_arrow_table()
labels = array_type(label_data, type=arrow_type) labels = array_type(label_data, type=arrow_type)
dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params()) dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
...@@ -175,7 +179,7 @@ def test_dataset_construct_weights_none(): ...@@ -175,7 +179,7 @@ def test_dataset_construct_weights_none():
[(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])], [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
) )
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()]) @pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any): def test_dataset_construct_weights(array_type, weight_data, arrow_type):
data = generate_dummy_arrow_table() data = generate_dummy_arrow_table()
weights = array_type(weight_data, type=arrow_type) weights = array_type(weight_data, type=arrow_type)
dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params()) dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
...@@ -183,3 +187,26 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type ...@@ -183,3 +187,26 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type
expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32) expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
np_assert_array_equal(expected, dataset.get_weight(), strict=True) np_assert_array_equal(expected, dataset.get_weight(), strict=True)
# -------------------------------------------- GROUPS ------------------------------------------- #
@pytest.mark.parametrize(
    ["array_type", "group_data"],
    [
        (pa.array, [2, 3]),
        (pa.chunked_array, [[2], [3]]),
        (pa.chunked_array, [[], [2, 3]]),
        (pa.chunked_array, [[2], [], [3], []]),
    ],
)
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES)
def test_dataset_construct_groups(array_type, group_data, arrow_type):
    """Check that group sizes passed as Arrow arrays are stored correctly.

    Every parametrization encodes the group sizes 2 and 3, as a plain array or
    as a chunked array with or without empty chunks, for each integer type.
    """
    features = generate_dummy_arrow_table()
    arrow_groups = array_type(group_data, type=arrow_type)
    ranking_dataset = lgb.Dataset(features, group=arrow_groups, params=dummy_dataset_params())
    ranking_dataset.construct()

    # The "group" field is compared against the cumulative boundaries of sizes [2, 3].
    query_boundaries = np.array([0, 2, 5], dtype=np.int32)
    np_assert_array_equal(query_boundaries, ranking_dataset.get_field("group"), strict=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment