Unverified Commit 661bde10 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python][tests] refactor tests with Sequence input (#4495)

parent 5d5f4909
......@@ -633,7 +633,7 @@ class Sequence(abc.ABC):
batch_size = 4096 # Defaults to read 4K rows in each batch.
@abc.abstractmethod
def __getitem__(self, idx: Union[int, slice, List[int]]) -> np.ndarray:
    """Return data for given row index.

    A basic implementation should look like this:

    .. code-block:: python

        if isinstance(idx, numbers.Integral):
            return self._get_one_line(idx)
        elif isinstance(idx, slice):
            return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
        elif isinstance(idx, list):
            # Only required if using ``Dataset.subset()``.
            return np.array([self._get_one_line(i) for i in idx])
        else:
            raise TypeError(f"Sequence index must be integer, slice or list, got {type(idx).__name__}")

    Parameters
    ----------
    idx : int, slice[int], list[int]
        Item index.

    Returns
    -------
    result : numpy 1-D array, numpy 2-D array
        1-D array if idx is int, 2-D array if idx is slice or list.
    """
    raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()")
......
# coding: utf-8
import filecmp
import numbers
import types
from pathlib import Path
import numpy as np
......@@ -106,6 +105,8 @@ class NumpySequence(lgb.Sequence):
if not (idx.step is None or idx.step == 1):
raise NotImplementedError("No need to implement, caller will not set step by now")
return self.ndarray[idx.start:idx.stop]
elif isinstance(idx, list):
return self.ndarray[idx]
else:
raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx).__name__}")
......@@ -195,29 +196,21 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)
@pytest.mark.parametrize('num_seq', [1, 2])
def test_sequence_get_data(num_seq):
    """Check that a Sequence-backed Dataset round-trips its raw data and supports subset()."""
    nrow = 20
    ncol = 11
    data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))
    X = data[:, :-1]
    Y = data[:, -1]

    # get_data() on a Sequence-backed Dataset should hand back the original sequences.
    seqs = _create_sequence_from_ndarray(data=X, num_seq=num_seq, batch_size=6)
    seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False).construct()
    assert seq_ds.get_data() == seqs

    # subset() exercises the list-index path of Sequence.__getitem__;
    # subset rows come back in sorted index order.
    used_indices = np.random.choice(np.arange(nrow), nrow // 3, replace=False)
    subset_data = seq_ds.subset(used_indices).construct()
    np.testing.assert_array_equal(subset_data.get_data(), X[sorted(used_indices)])
def test_chunked_dataset():
......@@ -339,8 +332,13 @@ def test_add_features_from_different_sources():
n_row = 100
n_col = 5
X = np.random.random((n_row, n_col))
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X), _create_sequence_from_ndarray(X, 1, 30)]
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
names = [f'col_{i}' for i in range(n_col)]
seq = _create_sequence_from_ndarray(X, 1, 30)
seq_ds = lgb.Dataset(seq, feature_name=names, free_raw_data=False).construct()
npy_list_ds = lgb.Dataset([X[:n_row // 2, :], X[n_row // 2:, :]],
feature_name=names, free_raw_data=False).construct()
immergeable_dds = [seq_ds, npy_list_ds]
for x_1 in xxs:
# test that method works even with free_raw_data=True
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=True).construct()
......@@ -350,8 +348,7 @@ def test_add_features_from_different_sources():
# test that method works but sets raw data to None in case of immergeable data types
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
d2 = lgb.Dataset([X[:n_row // 2, :], X[n_row // 2:, :]],
feature_name=names, free_raw_data=False).construct()
for d2 in immergeable_dds:
d1.add_features_from(d2)
assert d1.data is None
......@@ -359,9 +356,6 @@ def test_add_features_from_different_sources():
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
res_feature_names = [name for name in names]
for idx, x_2 in enumerate(xxs, 2):
# Dataset.get_data does not support Sequence input.
if isinstance(x_1, lgb.Sequence) or isinstance(x_2, lgb.Sequence):
continue
original_type = type(d1.get_data())
d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct()
d1.add_features_from(d2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment