Unverified commit deb70773, authored by Oliver Borchert, committed by GitHub

[python-package] Allow to pass Arrow array as weights (#6164)

parent 501e6e62
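In short: `lgb.Dataset` now accepts a pyarrow `Array` or `ChunkedArray` for its `weight` argument, mirroring the existing Arrow support for labels. A minimal usage sketch of the new capability (the tiny-dataset params are only there so the toy example can be constructed, as in the tests further down):

import pyarrow as pa
import lightgbm as lgb

table = pa.table({"feat1": [0.2, 0.5, 0.8, 0.1, 0.9], "feat2": [1.0, 3.0, 2.0, 4.0, 0.5]})
labels = pa.array([0.0, 1.0, 1.0, 0.0, 1.0])
# Weights may be a contiguous Array or a multi-chunk ChunkedArray.
weights = pa.chunked_array([[1.5], [0.5, 2.0, 1.0, 1.0]])

ds = lgb.Dataset(table, label=labels, weight=weights,
                 params={"min_data_in_bin": 1, "min_data_in_leaf": 1})
ds.construct()  # weights are forwarded to the native library as Arrow chunks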
include/LightGBM/c_api.h
@@ -558,9 +558,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
 /*!
  * \brief Set vector to a content in info.
  * \note
- * - \a label convert input datatype into ``float32``.
+ * - \a label and \a weight convert input datatype into ``float32``.
  * \param handle Handle of dataset
- * \param field_name Field name, can be \a label
+ * \param field_name Field name, can be \a label, \a weight
  * \param n_chunks The number of Arrow arrays passed to this function
  * \param chunks Pointer to the list of Arrow arrays
  * \param schema Pointer to the schema of all Arrow arrays
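The doc comment above belongs to the Arrow field setter (`LGBM_DatasetSetFieldFromArrow`, per the `n_chunks`/`chunks`/`schema` parameters), which the Python package reaches through `Dataset.set_field`. A hedged sketch of that lower-level path (toy params again; the float32 narrowing is the behaviour documented in the note):

import numpy as np
import pyarrow as pa
import lightgbm as lgb

X = np.random.rand(5, 2)
ds = lgb.Dataset(X, params={"min_data_in_bin": 1, "min_data_in_leaf": 1}).construct()

# "weight" is now an accepted field name alongside "label"; the native side
# converts the float64 Arrow values to float32.
ds.set_field("weight", pa.array([3.0, 0.7, 1.5, 0.5, 0.1]))
print(ds.get_field("weight").dtype)  # float32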
include/LightGBM/dataset.h
@@ -113,6 +113,7 @@ class Metadata {
   void SetLabel(const ArrowChunkedArray& array);
   void SetWeights(const label_t* weights, data_size_t len);
+  void SetWeights(const ArrowChunkedArray& array);
   void SetQuery(const data_size_t* query, data_size_t len);
@@ -340,6 +341,9 @@ class Metadata {
   void SetLabelsFromIterator(It first, It last);
   /*! \brief Insert weights at the given index */
   void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
+  /*! \brief Set weights from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetWeightsFromIterator(It first, It last);
   /*! \brief Insert initial scores at the given index */
   void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
   /*! \brief Insert queries at the given index */
python-package/lightgbm/basic.py
@@ -19,7 +19,8 @@ import numpy as np
 import scipy.sparse
-from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
-                     dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
+from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
+                     dt_DataTable, pa_Array, pa_ChunkedArray, pa_compute, pa_Table, pd_CategoricalDtype, pd_DataFrame,
+                     pd_Series)
 from .libpath import find_lib_path
 if TYPE_CHECKING:
@@ -115,7 +116,9 @@ _LGBM_WeightType = Union[
     List[float],
     List[int],
     np.ndarray,
-    pd_Series
+    pd_Series,
+    pa_Array,
+    pa_ChunkedArray,
 ]
 ZERO_THRESHOLD = 1e-35
@@ -1635,7 +1638,7 @@ class Dataset:
             Label of the data.
         reference : Dataset or None, optional (default=None)
             If this is Dataset for validation, training data should be used as reference.
-        weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each instance. Weights should be non-negative.
        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
            Group/query data.
@@ -2415,7 +2418,7 @@ class Dataset:
            If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
        label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Label of the data.
-        weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each instance. Weights should be non-negative.
        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
            Group/query data.
@@ -2830,7 +2833,7 @@ class Dataset:
        Parameters
        ----------
-        weight : list, numpy 1-D array, pandas Series or None
+        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
            Weight to be set for each data point. Weights should be non-negative.
        Returns
@@ -2838,11 +2841,19 @@ class Dataset:
        self : Dataset
            Dataset with set weight.
        """
-        if weight is not None and np.all(weight == 1):
-            weight = None
+        # Check if the weight contains values other than one
+        if weight is not None:
+            if _is_pyarrow_array(weight):
+                if pa_compute.all(pa_compute.equal(weight, 1)).as_py():
+                    weight = None
+            elif np.all(weight == 1):
+                weight = None
        self.weight = weight
+        # Set field
        if self._handle is not None and weight is not None:
-            weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
+            if not _is_pyarrow_array(weight):
+                weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
            self.set_field('weight', weight)
            self.weight = self.get_field('weight')  # original values can be modified at cpp side
        return self
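The short-circuit that drops all-ones weights needs an Arrow-native reduction, since np.all(weight == 1) does not understand possibly multi-chunk Arrow data. A quick illustration of both branches; pa_compute is simply pyarrow.compute under the alias used in lightgbm.compat:

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

w_arrow = pa.chunked_array([[1.0], [1.0, 1.0]])
w_numpy = np.array([1.0, 1.0, 1.0])

# Arrow branch: elementwise equality, an all() reduction, then unwrap to a Python bool.
assert pc.all(pc.equal(w_arrow, 1)).as_py() is True
# NumPy/list branch: the pre-existing check, unchanged.
assert np.all(w_numpy == 1)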
@@ -4414,7 +4425,7 @@ class Booster:
            .. versionadded:: 4.0.0
-        weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
            Weight for each ``data`` instance. Weights should be non-negative.
            .. versionadded:: 4.0.0
python-package/lightgbm/compat.py
@@ -197,6 +197,7 @@ except ImportError:
 """pyarrow"""
 try:
+    import pyarrow.compute as pa_compute
     from pyarrow import Array as pa_Array
     from pyarrow import ChunkedArray as pa_ChunkedArray
     from pyarrow import Table as pa_Table
@@ -236,6 +237,12 @@ except ImportError:
         def __init__(self, *args, **kwargs):
             pass
+    class pa_compute:  # type: ignore
+        """Dummy class for pyarrow.compute."""
+
+        all = None
+        equal = None
     arrow_is_integer = None
     arrow_is_floating = None
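The fallback class keeps pa_compute importable when pyarrow is absent, so module-level imports and type annotations in basic.py never fail; its attributes are only touched after _is_pyarrow_array has confirmed real Arrow input. A stripped-down sketch of the pattern (simplified, not the exact compat.py source):

try:
    import pyarrow.compute as pa_compute
    PYARROW_INSTALLED = True
except ImportError:
    PYARROW_INSTALLED = False

    class pa_compute:  # placeholder: attribute lookups resolve, calls would fail
        """Dummy class for pyarrow.compute."""

        all = None
        equal = None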
src/io/dataset.cpp
@@ -902,6 +902,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
   name = Common::Trim(name);
   if (name == std::string("label") || name == std::string("target")) {
     metadata_.SetLabel(ca);
+  } else if (name == std::string("weight") || name == std::string("weights")) {
+    metadata_.SetWeights(ca);
   } else {
     return false;
   }
src/io/metadata.cpp
@@ -450,33 +450,45 @@ void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data
   // CUDA is handled after all insertions are complete
 }

-void Metadata::SetWeights(const label_t* weights, data_size_t len) {
+template <typename It>
+void Metadata::SetWeightsFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  // save to nullptr
-  if (weights == nullptr || len == 0) {
+  // Clear weights on empty input
+  if (last - first == 0) {
     weights_.clear();
     num_weights_ = 0;
     return;
   }
-  if (num_data_ != len) {
-    Log::Fatal("Length of weights is not same with #data");
-  }
-  if (weights_.empty()) {
-    weights_.resize(num_data_);
-  }
+  if (num_data_ != last - first) {
+    Log::Fatal("Length of weights differs from the length of #data");
+  }
+  if (weights_.empty()) { weights_.resize(num_data_); }
   num_weights_ = num_data_;
 #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_weights_ >= 1024)
   for (data_size_t i = 0; i < num_weights_; ++i) {
-    weights_[i] = Common::AvoidInf(weights[i]);
+    weights_[i] = Common::AvoidInf(first[i]);
   }
   CalculateQueryWeights();
   weight_load_from_file_ = false;
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
-    cuda_metadata_->SetWeights(weights_.data(), len);
+    cuda_metadata_->SetWeights(weights_.data(), weights_.size());
   }
 #endif  // USE_CUDA
 }

+void Metadata::SetWeights(const label_t* weights, data_size_t len) {
+  SetWeightsFromIterator(weights, weights + len);
+}
+
+void Metadata::SetWeights(const ArrowChunkedArray& array) {
+  SetWeightsFromIterator(array.begin<label_t>(), array.end<label_t>());
+}
+
 void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len) {
   if (!weights) {
     Log::Fatal("Passed null weights");
tests/python_package_test/test_arrow.py
@@ -9,6 +9,8 @@ import pytest
 import lightgbm as lgb

+from .utils import np_assert_array_equal
+
 # ----------------------------------------------------------------------------------------------- #
 #                                            UTILITIES                                             #
 # ----------------------------------------------------------------------------------------------- #
@@ -67,10 +69,6 @@ def dummy_dataset_params() -> Dict[str, Any]:
     }

-def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray):
-    assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs)
-
 # ----------------------------------------------------------------------------------------------- #
 #                                            UNIT TESTS                                            #
 # ----------------------------------------------------------------------------------------------- #
@@ -103,6 +101,34 @@ def test_dataset_construct_fuzzy(
     assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")

+# -------------------------------------------- FIELDS ------------------------------------------- #
+
+def test_dataset_construct_fields_fuzzy():
+    arrow_table = generate_random_arrow_table(3, 1000, 42)
+    arrow_labels = generate_random_arrow_array(1000, 42)
+    arrow_weights = generate_random_arrow_array(1000, 42)
+
+    arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
+    arrow_dataset.construct()
+
+    pandas_dataset = lgb.Dataset(
+        arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
+    )
+    pandas_dataset.construct()
+
+    # Check for equality
+    for field in ("label", "weight"):
+        np_assert_array_equal(
+            arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
+        )
+    np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True)
+    np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True)
+
+# -------------------------------------------- LABELS ------------------------------------------- #
+
 @pytest.mark.parametrize(
     ["array_type", "label_data"],
     [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
@@ -129,17 +155,31 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type:
     dataset.construct()
     expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
-    assert_arrays_equal(expected, dataset.get_label())
+    np_assert_array_equal(expected, dataset.get_label(), strict=True)

-def test_dataset_construct_labels_fuzzy():
-    arrow_table = generate_random_arrow_table(3, 1000, 42)
-    arrow_array = generate_random_arrow_array(1000, 42)
-    arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array)
-    arrow_dataset.construct()
-    pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy())
-    pandas_dataset.construct()
-    assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label())
+# ------------------------------------------- WEIGHTS ------------------------------------------- #
+
+def test_dataset_construct_weights_none():
+    data = generate_dummy_arrow_table()
+    weight = pa.array([1, 1, 1, 1, 1])
+    dataset = lgb.Dataset(data, weight=weight, params=dummy_dataset_params())
+    dataset.construct()
+    assert dataset.get_weight() is None
+    assert dataset.get_field("weight") is None
+
+@pytest.mark.parametrize(
+    ["array_type", "weight_data"],
+    [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
+)
+@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
+def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
+    data = generate_dummy_arrow_table()
+    weights = array_type(weight_data, type=arrow_type)
+    dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
+    dataset.construct()
+    expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
+    np_assert_array_equal(expected, dataset.get_weight(), strict=True)
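Note the expected dtype in these tests: per the C API note at the top of this commit, weights of any supported Arrow type are narrowed to float32 on the native side, so get_weight() returns a float32 NumPy array. A small round-trip sketch (params mirror the tests' dummy_dataset_params; table contents are illustrative):

import numpy as np
import pyarrow as pa
import lightgbm as lgb

table = pa.table({"a": [0.1, 0.2, 0.3, 0.4, 0.5], "b": [1.0, 0.0, 1.0, 0.0, 1.0]})
weights = pa.array([3, 0.7, 1.5, 0.5, 0.1], type=pa.float64())

ds = lgb.Dataset(table, weight=weights, params={"min_data_in_bin": 1, "min_data_in_leaf": 1})
ds.construct()

back = ds.get_weight()
assert back.dtype == np.float32  # float64 input narrowed to float32
np.testing.assert_allclose(back, [3, 0.7, 1.5, 0.5, 0.1], rtol=1e-6)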