Unverified Commit 661bde10 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python][tests] refactor tests with Sequence input (#4495)

parent 5d5f4909
...@@ -633,7 +633,7 @@ class Sequence(abc.ABC): ...@@ -633,7 +633,7 @@ class Sequence(abc.ABC):
batch_size = 4096 # Defaults to read 4K rows in each batch. batch_size = 4096 # Defaults to read 4K rows in each batch.
@abc.abstractmethod @abc.abstractmethod
def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: def __getitem__(self, idx: Union[int, slice, List[int]]) -> np.ndarray:
"""Return data for given row index. """Return data for given row index.
A basic implementation should look like this: A basic implementation should look like this:
...@@ -645,20 +645,20 @@ class Sequence(abc.ABC): ...@@ -645,20 +645,20 @@ class Sequence(abc.ABC):
elif isinstance(idx, slice): elif isinstance(idx, slice):
return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)]) return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
elif isinstance(idx, list): elif isinstance(idx, list):
# Only required if using ``Dataset.get_data()``. # Only required if using ``Dataset.subset()``.
return np.array([self._get_one_line(i) for i in idx]) return np.array([self._get_one_line(i) for i in idx])
else: else:
raise TypeError(f"Sequence index must be integer or slice, got {type(idx).__name__}") raise TypeError(f"Sequence index must be integer, slice or list, got {type(idx).__name__}")
Parameters Parameters
---------- ----------
idx : int, slice[int] idx : int, slice[int], list[int]
Item index. Item index.
Returns Returns
------- -------
result : numpy 1-D array, numpy 2-D array result : numpy 1-D array, numpy 2-D array
1-D array if idx is int, 2-D array if idx is slice. 1-D array if idx is int, 2-D array if idx is slice or list.
""" """
raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()") raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()")
......
# coding: utf-8 # coding: utf-8
import filecmp import filecmp
import numbers import numbers
import types
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -106,6 +105,8 @@ class NumpySequence(lgb.Sequence): ...@@ -106,6 +105,8 @@ class NumpySequence(lgb.Sequence):
if not (idx.step is None or idx.step == 1): if not (idx.step is None or idx.step == 1):
raise NotImplementedError("No need to implement, caller will not set step by now") raise NotImplementedError("No need to implement, caller will not set step by now")
return self.ndarray[idx.start:idx.stop] return self.ndarray[idx.start:idx.stop]
elif isinstance(idx, list):
return self.ndarray[idx]
else: else:
raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx).__name__}") raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx).__name__}")
...@@ -195,29 +196,21 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq): ...@@ -195,29 +196,21 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname) assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)
def test_sequence_get_data(): @pytest.mark.parametrize('num_seq', [1, 2])
def test_sequence_get_data(num_seq):
nrow = 20 nrow = 20
ncol = 11 ncol = 11
data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol)) data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))
X = data[:, :-1] X = data[:, :-1]
Y = data[:, -1] Y = data[:, -1]
seqs = _create_sequence_from_ndarray(X, 2, 6) seqs = _create_sequence_from_ndarray(data=X, num_seq=num_seq, batch_size=6)
seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False) seq_ds = lgb.Dataset(seqs, label=Y, params=None, free_raw_data=False).construct()
seq_ds.construct() assert seq_ds.get_data() == seqs
assert seqs == seq_ds.get_data()
# This is a hack to add test coverage in get_data.
used_indices = [0, 5, 11, 15]
ref_data = types.SimpleNamespace()
ref_data.data = seqs
seq_ds.need_slice = True used_indices = np.random.choice(np.arange(nrow), nrow // 3, replace=False)
seq_ds.reference = ref_data subset_data = seq_ds.subset(used_indices).construct()
seq_ds.used_indices = used_indices np.testing.assert_array_equal(subset_data.get_data(), X[sorted(used_indices)])
assert (X[used_indices] == seq_ds.get_data()).all()
def test_chunked_dataset(): def test_chunked_dataset():
...@@ -339,8 +332,13 @@ def test_add_features_from_different_sources(): ...@@ -339,8 +332,13 @@ def test_add_features_from_different_sources():
n_row = 100 n_row = 100
n_col = 5 n_col = 5
X = np.random.random((n_row, n_col)) X = np.random.random((n_row, n_col))
xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X), _create_sequence_from_ndarray(X, 1, 30)] xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
names = [f'col_{i}' for i in range(n_col)] names = [f'col_{i}' for i in range(n_col)]
seq = _create_sequence_from_ndarray(X, 1, 30)
seq_ds = lgb.Dataset(seq, feature_name=names, free_raw_data=False).construct()
npy_list_ds = lgb.Dataset([X[:n_row // 2, :], X[n_row // 2:, :]],
feature_name=names, free_raw_data=False).construct()
immergeable_dds = [seq_ds, npy_list_ds]
for x_1 in xxs: for x_1 in xxs:
# test that method works even with free_raw_data=True # test that method works even with free_raw_data=True
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=True).construct() d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=True).construct()
...@@ -350,18 +348,14 @@ def test_add_features_from_different_sources(): ...@@ -350,18 +348,14 @@ def test_add_features_from_different_sources():
# test that method works but sets raw data to None in case of immergeable data types # test that method works but sets raw data to None in case of immergeable data types
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct() d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
d2 = lgb.Dataset([X[:n_row // 2, :], X[n_row // 2:, :]], for d2 in immergeable_dds:
feature_name=names, free_raw_data=False).construct() d1.add_features_from(d2)
d1.add_features_from(d2) assert d1.data is None
assert d1.data is None
# test that method works for different data types # test that method works for different data types
d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct() d1 = lgb.Dataset(x_1, feature_name=names, free_raw_data=False).construct()
res_feature_names = [name for name in names] res_feature_names = [name for name in names]
for idx, x_2 in enumerate(xxs, 2): for idx, x_2 in enumerate(xxs, 2):
# Dataset.get_data does not support Sequence input.
if isinstance(x_1, lgb.Sequence) or isinstance(x_2, lgb.Sequence):
continue
original_type = type(d1.get_data()) original_type = type(d1.get_data())
d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct() d2 = lgb.Dataset(x_2, feature_name=names, free_raw_data=False).construct()
d1.add_features_from(d2) d1.add_features_from(d2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment