"vscode:/vscode.git/clone" did not exist on "085691391bd6ffe2844c6e6ff3201a8efee19e64"
test_array_readwriter.py 764 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import tempfile

import numpy as np
import pytest
from distpartitioning import array_readwriter


@pytest.mark.parametrize(
    "shape", [[500], [300, 10], [200, 5, 5], [100, 5, 5, 5]]
)
@pytest.mark.parametrize("format", ["numpy", "parquet"])
def test_array_readwriter(format, shape):
    """Round-trip an array through the parser selected by ``format``.

    An array of the requested ``shape`` is written to a temporary file with
    the parser returned by ``array_readwriter.get_array_parser`` and then
    read back with a freshly obtained parser. The array read back must have
    the same shape and the same element values as the one written.

    Args:
        format: Parser name passed to ``get_array_parser`` as ``name``; also
            used as the file extension of the temporary file.
        shape: Dimensions of the array to generate (1-D through 4-D).
    """
    # A locally seeded generator keeps the input identical across runs, so a
    # failing parametrized case can be reproduced exactly, and it does not
    # touch NumPy's global random state.
    rng = np.random.default_rng(0)
    original_array = rng.random(tuple(shape))
    fmt_meta = {"name": format}

    with tempfile.TemporaryDirectory() as test_dir:
        path = os.path.join(test_dir, f"nodes.{format}")
        array_readwriter.get_array_parser(**fmt_meta).write(
            path, original_array
        )
        # Obtain a new parser for reading rather than reusing the writer.
        array = array_readwriter.get_array_parser(**fmt_meta).read(path)

        # Checks stay inside the ``with`` block so the file still exists
        # while the loaded array is being inspected.
        assert original_array.shape == array.shape, (
            f"shape changed in round trip: wrote {original_array.shape}, "
            f"read {array.shape}"
        )
        # Reports the mismatching elements on failure instead of a bare
        # ``False``.
        np.testing.assert_array_equal(array, original_array)