Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
661bde10
Unverified
Commit
661bde10
authored
Jul 31, 2021
by
Nikita Titov
Committed by
GitHub
Jul 31, 2021
Browse files
[python][tests] refactor tests with Sequence input (#4495)
parent
5d5f4909
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
30 deletions
+24
-30
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+5
-5
tests/python_package_test/test_basic.py
tests/python_package_test/test_basic.py
+19
-25
No files found.
python-package/lightgbm/basic.py
View file @
661bde10
...
...
@@ -633,7 +633,7 @@ class Sequence(abc.ABC):
batch_size
=
4096
# Defaults to read 4K rows in each batch.
@
abc
.
abstractmethod
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
])
->
np
.
ndarray
:
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
,
List
[
int
]
])
->
np
.
ndarray
:
"""Return data for given row index.
A basic implementation should look like this:
...
...
@@ -645,20 +645,20 @@ class Sequence(abc.ABC):
elif isinstance(idx, slice):
return np.stack([self._get_one_line(i) for i in range(idx.start, idx.stop)])
elif isinstance(idx, list):
# Only required if using ``Dataset.
get_data
()``.
# Only required if using ``Dataset.
subset
()``.
return np.array([self._get_one_line(i) for i in idx])
else:
raise TypeError(f"Sequence index must be integer
or
slice, got {type(idx).__name__}")
raise TypeError(f"Sequence index must be integer
,
slice
or list
, got {type(idx).__name__}")
Parameters
----------
idx : int, slice[int]
idx : int, slice[int]
, list[int]
Item index.
Returns
-------
result : numpy 1-D array, numpy 2-D array
1-D array if idx is int, 2-D array if idx is slice.
1-D array if idx is int, 2-D array if idx is slice
or list
.
"""
raise
NotImplementedError
(
"Sub-classes of lightgbm.Sequence must implement __getitem__()"
)
...
...
tests/python_package_test/test_basic.py
View file @
661bde10
# coding: utf-8
import
filecmp
import
numbers
import
types
from
pathlib
import
Path
import
numpy
as
np
...
...
@@ -106,6 +105,8 @@ class NumpySequence(lgb.Sequence):
if
not
(
idx
.
step
is
None
or
idx
.
step
==
1
):
raise
NotImplementedError
(
"No need to implement, caller will not set step by now"
)
return
self
.
ndarray
[
idx
.
start
:
idx
.
stop
]
elif
isinstance
(
idx
,
list
):
return
self
.
ndarray
[
idx
]
else
:
raise
TypeError
(
f
"Sequence Index must be an integer/list/slice, got
{
type
(
idx
).
__name__
}
"
)
...
...
@@ -195,29 +196,21 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
assert
filecmp
.
cmp
(
valid_npy_bin_fname
,
valid_seq2_bin_fname
)
def
test_sequence_get_data
():
@
pytest
.
mark
.
parametrize
(
'num_seq'
,
[
1
,
2
])
def
test_sequence_get_data
(
num_seq
):
nrow
=
20
ncol
=
11
data
=
np
.
arange
(
nrow
*
ncol
,
dtype
=
np
.
float64
).
reshape
((
nrow
,
ncol
))
X
=
data
[:,
:
-
1
]
Y
=
data
[:,
-
1
]
seqs
=
_create_sequence_from_ndarray
(
X
,
2
,
6
)
seq_ds
=
lgb
.
Dataset
(
seqs
,
label
=
Y
,
params
=
None
,
free_raw_data
=
False
)
seq_ds
.
construct
()
assert
seqs
==
seq_ds
.
get_data
()
seqs
=
_create_sequence_from_ndarray
(
data
=
X
,
num_seq
=
num_seq
,
batch_size
=
6
)
seq_ds
=
lgb
.
Dataset
(
seqs
,
label
=
Y
,
params
=
None
,
free_raw_data
=
False
).
construct
()
assert
seq_ds
.
get_data
()
==
seqs
# This is a hack to add test coverage in get_data.
used_indices
=
[
0
,
5
,
11
,
15
]
ref_data
=
types
.
SimpleNamespace
()
ref_data
.
data
=
seqs
seq_ds
.
need_slice
=
True
seq_ds
.
reference
=
ref_data
seq_ds
.
used_indices
=
used_indices
assert
(
X
[
used_indices
]
==
seq_ds
.
get_data
()).
all
()
used_indices
=
np
.
random
.
choice
(
np
.
arange
(
nrow
),
nrow
//
3
,
replace
=
False
)
subset_data
=
seq_ds
.
subset
(
used_indices
).
construct
()
np
.
testing
.
assert_array_equal
(
subset_data
.
get_data
(),
X
[
sorted
(
used_indices
)])
def
test_chunked_dataset
():
...
...
@@ -339,8 +332,13 @@ def test_add_features_from_different_sources():
n_row
=
100
n_col
=
5
X
=
np
.
random
.
random
((
n_row
,
n_col
))
xxs
=
[
X
,
sparse
.
csr_matrix
(
X
),
pd
.
DataFrame
(
X
)
,
_create_sequence_from_ndarray
(
X
,
1
,
30
)
]
xxs
=
[
X
,
sparse
.
csr_matrix
(
X
),
pd
.
DataFrame
(
X
)]
names
=
[
f
'col_
{
i
}
'
for
i
in
range
(
n_col
)]
seq
=
_create_sequence_from_ndarray
(
X
,
1
,
30
)
seq_ds
=
lgb
.
Dataset
(
seq
,
feature_name
=
names
,
free_raw_data
=
False
).
construct
()
npy_list_ds
=
lgb
.
Dataset
([
X
[:
n_row
//
2
,
:],
X
[
n_row
//
2
:,
:]],
feature_name
=
names
,
free_raw_data
=
False
).
construct
()
immergeable_dds
=
[
seq_ds
,
npy_list_ds
]
for
x_1
in
xxs
:
# test that method works even with free_raw_data=True
d1
=
lgb
.
Dataset
(
x_1
,
feature_name
=
names
,
free_raw_data
=
True
).
construct
()
...
...
@@ -350,8 +348,7 @@ def test_add_features_from_different_sources():
# test that method works but sets raw data to None in case of immergeable data types
d1
=
lgb
.
Dataset
(
x_1
,
feature_name
=
names
,
free_raw_data
=
False
).
construct
()
d2
=
lgb
.
Dataset
([
X
[:
n_row
//
2
,
:],
X
[
n_row
//
2
:,
:]],
feature_name
=
names
,
free_raw_data
=
False
).
construct
()
for
d2
in
immergeable_dds
:
d1
.
add_features_from
(
d2
)
assert
d1
.
data
is
None
...
...
@@ -359,9 +356,6 @@ def test_add_features_from_different_sources():
d1
=
lgb
.
Dataset
(
x_1
,
feature_name
=
names
,
free_raw_data
=
False
).
construct
()
res_feature_names
=
[
name
for
name
in
names
]
for
idx
,
x_2
in
enumerate
(
xxs
,
2
):
# Dataset.get_data does not support Sequence input.
if
isinstance
(
x_1
,
lgb
.
Sequence
)
or
isinstance
(
x_2
,
lgb
.
Sequence
):
continue
original_type
=
type
(
d1
.
get_data
())
d2
=
lgb
.
Dataset
(
x_2
,
feature_name
=
names
,
free_raw_data
=
False
).
construct
()
d1
.
add_features_from
(
d2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment