Unverified commit 1d21d1ad authored by Chen Yufei, committed by GitHub
Browse files

[python] support Dataset.get_data for Sequence input. (#4472)



* [python] support Dataset.get_data for Sequence input.

* Tweaks according to review comments.

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Add test cases.

* fix import order in test_basic.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 2370961a
...@@ -641,9 +641,12 @@ class Sequence(abc.ABC): ...@@ -641,9 +641,12 @@ class Sequence(abc.ABC):
.. code-block:: python .. code-block:: python
if isinstance(idx, numbers.Integral): if isinstance(idx, numbers.Integral):
return self.__get_one_line__(idx) return self._get_one_line(idx)
elif isinstance(idx, slice): elif isinstance(idx, slice):
return np.stack(self.__get_one_line__(i) for i in range(idx.start, idx.stop)) return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
elif isinstance(idx, list):
# Only required if using ``Dataset.get_data()``.
return np.array([self._get_one_line(i) for i in idx])
else: else:
raise TypeError(f"Sequence index must be integer or slice, got {type(idx).__name__}") raise TypeError(f"Sequence index must be integer or slice, got {type(idx).__name__}")
...@@ -1515,7 +1518,8 @@ class Dataset: ...@@ -1515,7 +1518,8 @@ class Dataset:
# set feature names # set feature names
return self.set_feature_name(feature_name) return self.set_feature_name(feature_name)
def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]): @staticmethod
def _yield_row_from_seqlist(seqs: List[Sequence], indices: Iterable[int]):
offset = 0 offset = 0
seq_id = 0 seq_id = 0
seq = seqs[seq_id] seq = seqs[seq_id]
...@@ -1541,7 +1545,7 @@ class Dataset: ...@@ -1541,7 +1545,7 @@ class Dataset:
indices = self._create_sample_indices(total_nrow) indices = self._create_sample_indices(total_nrow)
# Select sampled rows, transpose to column order. # Select sampled rows, transpose to column order.
sampled = np.array([row for row in self.__yield_row_from(seqs, indices)]) sampled = np.array([row for row in self._yield_row_from_seqlist(seqs, indices)])
sampled = sampled.T sampled = sampled.T
filtered = [] filtered = []
...@@ -2236,7 +2240,7 @@ class Dataset: ...@@ -2236,7 +2240,7 @@ class Dataset:
Returns Returns
------- -------
data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, list of numpy arrays or None data : string, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays or None
Raw data used in the Dataset construction. Raw data used in the Dataset construction.
""" """
if self.handle is None: if self.handle is None:
...@@ -2250,6 +2254,10 @@ class Dataset: ...@@ -2250,6 +2254,10 @@ class Dataset:
self.data = self.data.iloc[self.used_indices].copy() self.data = self.data.iloc[self.used_indices].copy()
elif isinstance(self.data, dt_DataTable): elif isinstance(self.data, dt_DataTable):
self.data = self.data[self.used_indices, :] self.data = self.data[self.used_indices, :]
elif isinstance(self.data, Sequence):
self.data = self.data[self.used_indices]
elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data):
self.data = np.array([row for row in self._yield_row_from_seqlist(self.data, self.used_indices)])
else: else:
_log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n" _log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"
"Returning original raw data") "Returning original raw data")
......
# coding: utf-8 # coding: utf-8
import filecmp import filecmp
import numbers import numbers
import types
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -194,6 +195,31 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq): ...@@ -194,6 +195,31 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname) assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)
def test_sequence_get_data():
    """Check Dataset.get_data round-trips Sequence input, with and without row slicing."""
    n_rows, n_cols = 20, 11
    raw = np.arange(n_rows * n_cols, dtype=np.float64).reshape((n_rows, n_cols))
    features = raw[:, :-1]
    labels = raw[:, -1]

    seqs = _create_sequence_from_ndarray(features, 2, 6)
    ds = lgb.Dataset(seqs, label=labels, params=None, free_raw_data=False)
    ds.construct()
    # Without slicing, get_data should hand back the original Sequence list untouched.
    assert seqs == ds.get_data()

    # Hack: fake a sliced Dataset (reference + used_indices) purely to drive
    # the Sequence-subsetting branch of get_data.
    row_subset = [0, 5, 11, 15]
    ds.reference = types.SimpleNamespace(data=seqs)
    ds.need_slice = True
    ds.used_indices = row_subset
    assert (features[row_subset] == ds.get_data()).all()
def test_chunked_dataset(): def test_chunked_dataset():
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,
random_state=2) random_state=2)
...@@ -313,7 +339,7 @@ def test_add_features_from_different_sources(): ...@@ -313,7 +339,7 @@ def test_add_features_from_different_sources():
n_row = 100 n_row = 100
n_col = 5 n_col = 5
X = np.random.random((n_row, n_col)) X = np.random.random((n_row, n_col))
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)] xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X), _create_sequence_from_ndarray(X, 1, 30)]
names = [f'col_{i}' for i in range(n_col)] names = [f'col_{i}' for i in range(n_col)]
for x_1 in xxs: for x_1 in xxs:
# test that method works even with free_raw_data=True # test that method works even with free_raw_data=True
...@@ -333,6 +359,9 @@ def test_add_features_from_different_sources(): ...@@ -333,6 +359,9 @@ def test_add_features_from_different_sources():
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct() d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
res_feature_names = [name for name in names] res_feature_names = [name for name in names]
for idx, x_2 in enumerate(xxs, 2): for idx, x_2 in enumerate(xxs, 2):
# Dataset.get_data does not support Sequence input.
if isinstance(x_1, lgb.Sequence) or isinstance(x_2, lgb.Sequence):
continue
original_type = type(d1.get_data()) original_type = type(d1.get_data())
d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct() d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct()
d1.add_features_from(d2) d1.add_features_from(d2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment