api.py 2.45 KB
Newer Older
drunkpig's avatar
drunkpig committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
from pathlib import Path

from loguru import logger

from magic_pdf.integrations.rag.type import (ElementRelation, LayoutElements,
                                             Node)
from magic_pdf.integrations.rag.utils import inference


class RagPageReader:
    """Iterable view over the layout elements of one page.

    Every entry of ``pagedata.layout_dets`` is copied into a ``Node`` (category
    type, text, image path, annotation id, LaTeX and HTML fields). Iterating
    the reader yields those nodes in the original ``layout_dets`` order.
    """

    def __init__(self, pagedata: LayoutElements):
        # Public attributes: ``o`` holds the converted nodes, ``pagedata``
        # keeps the source object so relation data can be read later.
        self.o = list(map(self._to_node, pagedata.layout_dets))
        self.pagedata = pagedata

    @staticmethod
    def _to_node(det) -> Node:
        """Copy the fields of a single layout detection into a ``Node``."""
        return Node(
            category_type=det.category_type,
            text=det.text,
            image_path=det.image_path,
            anno_id=det.anno_id,
            latex=det.latex,
            html=det.html,
        )

    def __iter__(self):
        """Iterate over the page's ``Node`` objects."""
        return iter(self.o)

    def get_rel_map(self) -> list[ElementRelation]:
        """Return the element relations stored in ``pagedata.extra``."""
        extra = self.pagedata.extra
        return extra.element_relation


class RagDocumentReader:
    """Iterable over a document's pages, one ``RagPageReader`` per page.

    The readers are built eagerly from ``ragdata`` and keep its page order.
    """

    def __init__(self, ragdata: list[LayoutElements]):
        # Public attribute ``o``: the per-page readers, in input order.
        self.o = list(map(RagPageReader, ragdata))

    def __iter__(self):
        """Iterate over the per-page ``RagPageReader`` objects."""
        return iter(self.o)


class DataReader:
    """Collect PDF files and run ``inference`` on them one document at a time.

    ``path_or_directory`` may name a single ``.pdf`` file or a directory. For
    a directory, every regular file directly inside it whose name ends in
    ``.pdf`` (case-insensitive) is collected; subdirectories are not searched.
    Documents are addressed by integer index into ``self.pdfs``.
    """

    def __init__(self, path_or_directory: str, method: str, output_dir: str):
        """
        Args:
            path_or_directory (str): a single PDF file, or a directory whose
                top-level PDF files should be processed. ``os.PathLike``
                values are accepted as well.
            method (str): forwarded unchanged to ``inference``.
            output_dir (str): forwarded unchanged to ``inference``.

        Raises:
            ValueError: if ``path_or_directory`` is not an existing directory
                and its name does not end in ``.pdf``.
        """
        self.path_or_directory = path_or_directory
        self.method = method
        self.output_dir = output_dir
        self.pdfs: list[Path] = []
        if os.path.isdir(path_or_directory):
            # Sort so an index refers to the same file on every run: directory
            # listing order is filesystem-dependent. Matching is
            # case-insensitive so ``*.PDF`` files are found on POSIX too, and
            # directories that happen to be named ``*.pdf`` are skipped.
            self.pdfs = sorted(
                doc_path for doc_path in Path(path_or_directory).iterdir()
                if doc_path.is_file()
                and doc_path.name.lower().endswith('.pdf'))
        else:
            # Explicit check instead of ``assert``: assertions are removed
            # under ``python -O``, which would let any path through.
            if not os.fspath(path_or_directory).lower().endswith('.pdf'):
                raise ValueError(
                    'expected a directory or a .pdf file, '
                    f'got {path_or_directory!r}')
            self.pdfs.append(Path(path_or_directory))

    def get_documents_count(self) -> int:
        """Return the number of PDF documents collected at construction."""
        return len(self.pdfs)

    def get_document_result(self, idx: int) -> 'RagDocumentReader | None':
        """Run inference on the ``idx``-th document and wrap the result.

        Inference is run on every call; results are not cached here.

        Args:
            idx (int): index into the collected documents, valid in the range
                ``[0, get_documents_count())``. Negative indices are rejected.

        Returns:
            RagDocumentReader | None: an iterable of per-page readers (see
            RagDocumentReader), or None when ``idx`` is out of range or
            ``inference`` returned None. Both failure cases are logged.
        """
        if not 0 <= idx < self.get_documents_count():
            logger.error(f'invalid idx: {idx}')
            return None
        pdf_path = self.pdfs[idx]
        res = inference(str(pdf_path), self.output_dir, self.method)
        if res is None:
            logger.warning(f'failed to inference pdf {pdf_path}')
            return None
        return RagDocumentReader(res)

    def get_document_filename(self, idx: int) -> Path:
        """Return the path of the ``idx``-th document.

        Unlike ``get_document_result``, ``idx`` is not range-checked: this is
        plain list indexing, so an index past the end raises IndexError and a
        negative index counts from the end.
        """
        return self.pdfs[idx]