read_api.py 3.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import json
import os
from pathlib import Path

from magic_pdf.config.exceptions import EmptyData, InvalidParams
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                               MultiBucketS3DataReader)
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset


def read_jsonl(
    s3_path_or_local: str, s3_client: MultiBucketS3DataReader | None = None
) -> list[PymuDocDataset]:
    """Read the jsonl file and return the list of PymuDocDataset.

    Every non-blank line of the jsonl file is parsed as a JSON object whose
    'file_location' key (or, failing that, 'path' key) gives the location of
    a pdf file. Each location may itself be a local path or an s3 path.

    Args:
        s3_path_or_local (str): local file or s3 path
        s3_client (MultiBucketS3DataReader | None, optional): s3 client that support multiple bucket. Defaults to None.

    Raises:
        InvalidParams: if s3_path_or_local, or a pdf file location found in
            the jsonl file, is an s3 path but s3_client is not provided.
        EmptyData: if no pdf file location is provided in some line of jsonl
            file (key missing, empty string, or JSON null).

    Returns:
        list[PymuDocDataset]: each line in the jsonl file will be converted to a PymuDocDataset
    """
    local_reader = FileBasedDataReader('')

    def _read(location: str) -> bytes:
        # Single place that decides between the s3 client and the local reader.
        if location.startswith('s3://'):
            if s3_client is None:
                raise InvalidParams('s3_client is required when s3_path is provided')
            return s3_client.read(location)
        return local_reader.read(location)

    jsonl_text = _read(s3_path_or_local).decode()
    # Split on '\n' only: str.splitlines() would also break on characters such
    # as U+2028 that may legally appear unescaped inside a JSON string.
    records = [json.loads(line) for line in jsonl_text.split('\n') if line.strip()]

    # Validate and read every pdf before building any dataset.
    bits_arr = []
    for record in records:
        pdf_path = record.get('file_location') or record.get('path')
        # Truthiness check (instead of len()) so that a JSON null is reported
        # as EmptyData rather than crashing with a TypeError.
        if not pdf_path:
            raise EmptyData('pdf file location is empty')
        bits_arr.append(_read(pdf_path))
    return [PymuDocDataset(bits) for bits in bits_arr]


def read_local_pdfs(path: str) -> list[PymuDocDataset]:
    """Read pdf from path or directory.

    Args:
        path (str): pdf file path or directory that contains pdf files

    Returns:
        list[PymuDocDataset]: each pdf file will be converted to a
            PymuDocDataset. For a directory, every file found by walking it
            recursively whose name ends with '.pdf' is read (empty list if
            there is none). A non-directory path is read as-is, whatever its
            name.
    """
    if not os.path.isdir(path):
        reader = FileBasedDataReader()
        return [PymuDocDataset(reader.read(path))]

    pdf_paths = []
    for root, _, files in os.walk(path):
        for file in files:
            # Require a real '.pdf' extension: comparing only the text after
            # the last dot would also accept a file named exactly 'pdf'.
            if file.endswith('.pdf'):
                pdf_paths.append(os.path.join(root, file))

    if not pdf_paths:
        # Nothing matched, so there is nothing to read.
        return []

    reader = FileBasedDataReader()
    return [PymuDocDataset(reader.read(pdf_path)) for pdf_path in pdf_paths]


xu rui's avatar
xu rui committed
75
def read_local_images(path: str, suffixes: list[str] | None = None) -> list[ImageDataset]:
    """Read images from path or directory.

    Args:
        path (str): image file path or directory that contains image files
        suffixes (list[str] | None, optional): the suffixes of the image files
            used to filter the files when path is a directory, with or without
            the leading dot. Example: ['jpg', 'png']. Defaults to None, which
            (like an empty list) matches no file.

    Returns:
        list[ImageDataset]: each image file will be converted to an
            ImageDataset. For a directory, every file found by walking it
            recursively whose extension is in suffixes is read. A
            non-directory path is read as-is and suffixes is ignored.
    """
    if not os.path.isdir(path):
        reader = FileBasedDataReader()
        bits = reader.read(path)
        return [ImageDataset(bits)]

    # Normalise to dotted extensions so that a file must really have the
    # extension: comparing only the text after the last dot would also accept
    # an extensionless file named exactly 'jpg'.
    extensions = tuple(
        suffix if suffix.startswith('.') else '.' + suffix
        for suffix in (suffixes or [])
    )

    img_paths = []
    for root, _, files in os.walk(path):
        for file in files:
            # str.endswith accepts a tuple; an empty tuple never matches.
            if file.endswith(extensions):
                img_paths.append(os.path.join(root, file))

    if not img_paths:
        # Nothing matched, so there is nothing to read.
        return []

    reader = FileBasedDataReader()
    imgs_bits = [reader.read(img_path) for img_path in img_paths]
    return [ImageDataset(bits) for bits in imgs_bits]