parquet.py
import ast
import logging

import numpy as np
import pandas as pd
import pyarrow
import pyarrow.parquet

from .registry import register_array_parser


@register_array_parser("parquet")
class ParquetArrayParser(object):
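    """Array parser that round-trips numpy ndarrays through parquet files.

    Registered under the name "parquet" in the array-parser registry.
    """
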
    def __init__(self):
        pass

    def read(self, path):
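        """Read the parquet file at ``path`` into a numpy ndarray.

        If the file's schema metadata carries a ``shape`` entry (as written
        by ``write``), the array is reshaped to it; otherwise the data is
        returned as a 2-dim array.
        """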
        logging.debug("Reading from %s using parquet format" % path)
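        # Fetch the key/value metadata attached to the file's Arrow schema;
        # write() stores the original array shape there.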
        metadata = pyarrow.parquet.read_metadata(path)
        metadata = metadata.schema.to_arrow_schema().metadata

        # Parquet data is tabular, so we assume the ndarray has 2 dimensions.
        # If not, the shape must be explicitly recorded in the file metadata.
        if metadata:
            shape = metadata.get(b"shape", None)
        else:
            shape = None
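        # memory_map=True lets Arrow read local files through a memory map,
        # which can avoid extra copies of the data.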
        table = pyarrow.parquet.read_table(path, memory_map=True)

        data_types = table.schema.types
        # Spark ML feature processing produces single-column parquet files
        # where each row is a vector object.
        if len(data_types) == 1 and isinstance(data_types[0], pyarrow.ListType):
            arr = np.array(table.to_pandas().iloc[:, 0].to_list())
            logging.debug(
                f"Parquet data under {path} converted from single vector per row to ndarray"
            )
        else:
            arr = table.to_pandas().to_numpy()
        if not shape:
            logging.debug(
                "Shape information not found in the metadata; reading the "
                "data as a 2-dim array."
            )
        logging.debug("Done reading from %s" % path)
        # Restore the original shape if one was recorded; ast.literal_eval
        # safely parses the "(d0, d1, ...)" string stored by write().
        shape = tuple(ast.literal_eval(shape.decode())) if shape else arr.shape
        return arr.reshape(shape)

    def write(self, path, array, vector_rows=False):
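        """Write ``array`` to ``path`` in parquet format.

        Arrays with more than two dimensions are flattened to 2-dim before
        writing. In the default tabular mode the original shape is stored in
        the schema metadata so ``read`` can restore it; in ``vector_rows``
        mode each row becomes a single list-valued "vector" column and no
        shape metadata is attached.
        """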
        logging.debug("Writing to %s using parquet format" % path)
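        # Parquet is tabular, so flatten anything above 2 dimensions; the
        # original shape is saved first so it can be recorded below.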
        shape = array.shape
        if len(shape) > 2:
            array = array.reshape(shape[0], -1)
        if vector_rows:
            table = pyarrow.table(
                [pyarrow.array(array.tolist())], names=["vector"]
            )
            logging.debug("Writing to %s using single-vector rows..." % path)
        else:
            table = pyarrow.Table.from_pandas(pd.DataFrame(array))
            table = table.replace_schema_metadata({"shape": str(shape)})

        pyarrow.parquet.write_table(table, path)
        logging.debug("Done writing to %s" % path)