# coding: utf-8
import filecmp
3
from typing import Any, Dict
4
5
6
7
8
9
10

import numpy as np
import pyarrow as pa
import pytest

import lightgbm as lgb

11
12
from .utils import np_assert_array_equal

13
14
15
16
# ----------------------------------------------------------------------------------------------- #
#                                            UTILITIES                                            #
# ----------------------------------------------------------------------------------------------- #

# Every Arrow integer type: the signed widths first, then the unsigned ones, 8 to 64 bits each.
_INTEGER_TYPES = [
    # signed
    pa.int8(), pa.int16(), pa.int32(), pa.int64(),
    # unsigned
    pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64(),
]
# Both Arrow floating point widths exercised by these tests.
_FLOAT_TYPES = [pa.float32(), pa.float64()]


def generate_simple_arrow_table() -> pa.Table:
    """Return a 5-row table with one single-chunk column per Arrow numeric type.

    Columns, in order: uint8, int8, uint16, int16, uint32, int32, uint64, int64, float32,
    float64. Every column holds the values 1..5 and they are named ``col_0`` .. ``col_9``.
    """
    arrow_types = [
        pa.uint8(), pa.int8(),
        pa.uint16(), pa.int16(),
        pa.uint32(), pa.int32(),
        pa.uint64(), pa.int64(),
        pa.float32(), pa.float64(),
    ]
    values = [1, 2, 3, 4, 5]
    columns = [pa.chunked_array([values], type=arrow_type) for arrow_type in arrow_types]
    names = [f"col_{idx}" for idx, _ in enumerate(columns)]
    return pa.Table.from_arrays(columns, names=names)


def generate_nullable_arrow_table() -> pa.Table:
    """Return a 5-row float32 table whose columns place nulls at different positions.

    Column ``col_0`` has a null in the second slot, ``col_1`` in the first, ``col_2`` in the
    last, and ``col_3`` is null everywhere. Each column is a single chunk.
    """
    column_values = (
        [1, None, 3, 4, 5],
        [None, 2, 3, 4, 5],
        [1, 2, 3, 4, None],
        [None] * 5,
    )
    columns = [pa.chunked_array([values], type=pa.float32()) for values in column_values]
    names = [f"col_{idx}" for idx in range(len(columns))]
    return pa.Table.from_arrays(columns, names=names)


def generate_dummy_arrow_table() -> pa.Table:
    """Return a 5-row, two-column table whose columns are chunked at different offsets.

    Column ``a`` (uint8) is split 3 + 2 and column ``b`` (float32) is split 2 + 3, so the chunk
    boundaries of the two columns do not line up.
    """
    a_chunks = [[1, 2, 3], [4, 5]]
    b_chunks = [[0.5, 0.6], [0.1, 0.8, 1.5]]
    return pa.Table.from_arrays(
        [
            pa.chunked_array(a_chunks, type=pa.uint8()),
            pa.chunked_array(b_chunks, type=pa.float32()),
        ],
        names=["a", "b"],
    )


def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table:
    """Return a table of ``num_columns`` random columns named ``col_0``, ``col_1``, ...

    Column ``i`` is produced by ``generate_random_arrow_array`` with seed ``seed + i``, so the
    whole table is reproducible from ``seed``.
    """
    columns = []
    names = []
    for offset in range(num_columns):
        columns.append(generate_random_arrow_array(num_datapoints, seed + offset))
        names.append(f"col_{offset}")
    return pa.Table.from_arrays(columns, names=names)


def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
    """Return a reproducible float32 chunked array of standard-normal values.

    Roughly 10% of the positions are overwritten with NaN. Positions are drawn with
    replacement, so there may be fewer distinct ones. The values are then split at two
    random points into three non-empty chunks.

    :param num_datapoints: total number of values; must be at least 3 so that two distinct
        split points exist in ``[1, num_datapoints - 1]``.
    :param seed: seed for ``np.random.default_rng``; equal seeds give equal arrays.
    :return: a ``pa.ChunkedArray`` of type ``float32`` with three chunks.
    """
    generator = np.random.default_rng(seed)
    data = generator.standard_normal(num_datapoints)

    # Overwrite random positions with NaN.
    # NOTE(review): with pyarrow's default `from_pandas=False`, NaN from a NumPy float array
    # stays a NaN value rather than becoming an Arrow null -- confirm whether true nulls are
    # intended here.
    indices = generator.choice(len(data), size=num_datapoints // 10)
    data[indices] = np.nan

    # Split data into 3 random chunks. The two split points are distinct and lie in
    # [1, num_datapoints - 1], so every chunk is non-empty.
    split_points = np.sort(generator.choice(np.arange(1, num_datapoints), 2, replace=False))
    boundaries = np.concatenate([[0], split_points, [num_datapoints]])
    chunks = [data[start:stop] for start, stop in zip(boundaries[:-1], boundaries[1:])]

    # Build the array from the chunks. Previously the whole `data` array was wrapped as a
    # single chunk, which silently discarded the chunking above.
    return pa.chunked_array(chunks, type=pa.float32())


def dummy_dataset_params() -> Dict[str, Any]:
    """Return Dataset params lowering the minimum data per bin and per leaf to 1.

    A fresh dict is built on every call, so callers may mutate the result freely.
    """
    return dict(min_data_in_bin=1, min_data_in_leaf=1)


# ----------------------------------------------------------------------------------------------- #
#                                            UNIT TESTS                                           #
# ----------------------------------------------------------------------------------------------- #

# ------------------------------------------- DATASET ------------------------------------------- #


@pytest.mark.parametrize(
    ("arrow_table_fn", "dataset_params"),
    [  # Use lambda functions here to minimize memory consumption
        (lambda: generate_simple_arrow_table(), dummy_dataset_params()),
        (lambda: generate_dummy_arrow_table(), dummy_dataset_params()),
        (lambda: generate_nullable_arrow_table(), dummy_dataset_params()),
        (lambda: generate_random_arrow_table(3, 1000, 42), {}),
        (lambda: generate_random_arrow_table(100, 10000, 43), {}),
    ],
)
def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
    """A Dataset built from an Arrow table must dump identically to one built from pandas.

    The same table is constructed once directly and once via ``to_pandas()``. The text dumps
    of the two constructed datasets are then compared byte for byte.
    """
    arrow_table = arrow_table_fn()

    arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
    arrow_dataset.construct()

    pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), params=dataset_params)
    pandas_dataset.construct()

    arrow_dataset._dump_text(tmp_path / "arrow.txt")
    pandas_dataset._dump_text(tmp_path / "pandas.txt")
    # shallow=False forces a content comparison. The default shallow mode treats two files
    # with identical os.stat() signatures (type, size, mtime) as equal without reading them,
    # and two dumps written back-to-back can easily share size and mtime.
    assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt", shallow=False)


# -------------------------------------------- FIELDS ------------------------------------------- #


def test_dataset_construct_fields_fuzzy():
    """Label, weight and group fields set from Arrow must match those set from NumPy.

    The same random data is passed once as Arrow objects and once converted to
    pandas/NumPy. The constructed datasets must expose identical ``label``, ``weight`` and
    ``group`` fields.
    """
    arrow_table = generate_random_arrow_table(3, 1000, 42)
    # Seeds are distinct from each other and from the feature columns (seeds 42, 43, 44), so
    # that mixing up label and weight handling in the Arrow path cannot go unnoticed.
    arrow_labels = generate_random_arrow_array(1000, 45)
    arrow_weights = generate_random_arrow_array(1000, 46)
    # Group sizes sum to 1000, the number of rows in the table.
    arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())

    arrow_dataset = lgb.Dataset(
        arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
    )
    arrow_dataset.construct()

    pandas_dataset = lgb.Dataset(
        arrow_table.to_pandas(),
        label=arrow_labels.to_numpy(),
        weight=arrow_weights.to_numpy(),
        group=arrow_groups.to_numpy(),
    )
    pandas_dataset.construct()

    # Check for equality
    for field in ("label", "weight", "group"):
        np_assert_array_equal(
            arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
        )
    np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True)
    np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True)


# -------------------------------------------- LABELS ------------------------------------------- #


@pytest.mark.parametrize(
    ["array_type", "label_data"],
    [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
)
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
def test_dataset_construct_labels(array_type, label_data, arrow_type):
    """Labels given as an Arrow (chunked) array of any numeric type read back correctly.

    The labels retrieved via ``get_label()`` are compared against a float32 expectation.
    """
    data = generate_dummy_arrow_table()
    labels = array_type(label_data, type=arrow_type)
    dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
    dataset.construct()

    expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
    np_assert_array_equal(expected, dataset.get_label(), strict=True)


# ------------------------------------------- WEIGHTS ------------------------------------------- #


def test_dataset_construct_weights_none():
    """Passing all-ones Arrow weights must leave the constructed dataset without weights.

    Both ``get_weight()`` and ``get_field("weight")`` are expected to return ``None``.
    """
    table = generate_dummy_arrow_table()
    uniform_weights = pa.array([1] * 5)
    dataset = lgb.Dataset(table, weight=uniform_weights, params=dummy_dataset_params())
    dataset.construct()

    assert dataset.get_weight() is None
    assert dataset.get_field("weight") is None


@pytest.mark.parametrize(
    ["array_type", "weight_data"],
    [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
)
@pytest.mark.parametrize("arrow_type", _FLOAT_TYPES)
def test_dataset_construct_weights(array_type, weight_data, arrow_type):
    """Float weights given as an Arrow (chunked) array read back correctly.

    The weights retrieved via ``get_weight()`` are compared against a float32 expectation.
    """
    data = generate_dummy_arrow_table()
    weights = array_type(weight_data, type=arrow_type)
    dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
    dataset.construct()

    expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
    np_assert_array_equal(expected, dataset.get_weight(), strict=True)


# -------------------------------------------- GROUPS ------------------------------------------- #


@pytest.mark.parametrize(
    ("array_type", "group_data"),
    [
        (pa.array, [2, 3]),
        (pa.chunked_array, [[2], [3]]),
        (pa.chunked_array, [[], [2, 3]]),
        (pa.chunked_array, [[2], [], [3], []]),
    ],
)
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES)
def test_dataset_construct_groups(array_type, group_data, arrow_type):
    """Group sizes 2 and 3, in any chunk layout and integer type, yield boundaries [0, 2, 5].

    The ``group`` field is compared against the cumulative boundaries as int32, including
    layouts with empty chunks.
    """
    table = generate_dummy_arrow_table()
    group_sizes = array_type(group_data, type=arrow_type)
    dataset = lgb.Dataset(table, group=group_sizes, params=dummy_dataset_params())
    dataset.construct()

    expected_boundaries = np.array([0, 2, 5], dtype=np.int32)
    np_assert_array_equal(expected_boundaries, dataset.get_field("group"), strict=True)


# ----------------------------------------- INIT SCORES ----------------------------------------- #


@pytest.mark.parametrize(
    ("array_type", "init_score_data"),
    [
        (pa.array, [0, 1, 2, 3, 3]),
        (pa.chunked_array, [[0, 1, 2], [3, 3]]),
        (pa.chunked_array, [[], [0, 1, 2], [3, 3]]),
        (pa.chunked_array, [[0, 1], [], [], [2], [3, 3], []]),
    ],
)
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
def test_dataset_construct_init_scores_array(
    array_type: Any, init_score_data: Any, arrow_type: Any
):
    """One-dimensional Arrow init scores, in any chunk layout and type, read back as float64.

    Layouts include empty leading, interleaved and trailing chunks.
    """
    table = generate_dummy_arrow_table()
    scores = array_type(init_score_data, type=arrow_type)
    dataset = lgb.Dataset(table, init_score=scores, params=dummy_dataset_params())
    dataset.construct()

    np_assert_array_equal(
        np.array([0, 1, 2, 3, 3], dtype=np.float64), dataset.get_init_score(), strict=True
    )


def test_dataset_construct_init_scores_table():
    """A three-column Arrow table of init scores reads back as the table's float64 matrix.

    The expectation is the table converted via pandas to a NumPy array and cast to float64.
    """
    data = generate_dummy_arrow_table()
    score_columns = [generate_random_arrow_array(5, seed=column_seed) for column_seed in (1, 2, 3)]
    init_scores = pa.Table.from_arrays(score_columns, names=["a", "b", "c"])
    dataset = lgb.Dataset(data, init_score=init_scores, params=dummy_dataset_params())
    dataset.construct()

    actual_scores = dataset.get_init_score()
    expected_scores = init_scores.to_pandas().to_numpy().astype(np.float64)
    np_assert_array_equal(expected_scores, actual_scores, strict=True)