magic_model.py 5.37 KB
Newer Older
1
2
3
4
5
6
7
import json

from magic_pdf.libs.commons import fitz
from loguru import logger

from magic_pdf.libs.commons import join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
赵小蒙's avatar
赵小蒙 committed
8
from magic_pdf.libs.ocr_content_type import ContentType
9
10
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
kernel.h@qq.com's avatar
kernel.h@qq.com committed
11
12
13
14
15
16
17


class MagicModel:
    """Per-page accessor over layout-model output for one PDF document.

    ``model_list`` holds one entry per page. Each entry carries
    ``page_info['page_no']`` and a ``layout_dets`` list of detections. On
    construction, every detection's ``poly`` is converted into an integer
    ``bbox`` (see ``__fix_axis``).

    Every getter returns an empty list when it finds no matching elements.
    """

    def __fix_axis(self):
        """Add an integer ``bbox`` to every layout detection and drop empty ones.

        For each page, the detection's ``poly`` coordinates are divided by the
        horizontal/vertical scale ratios returned by ``get_scale_ratio`` for
        that page, truncated to ``int``, and stored as
        ``layout_det["bbox"] = [x0, y0, x1, y1]``. The scaling presumably maps
        model-image pixels to PDF page units (verify against
        ``get_scale_ratio``). Detections whose bbox has zero width or zero
        height are then removed from that page's ``layout_dets`` list in place.
        """
        for model_page_info in self.__model_list:
            page_no = model_page_info['page_info']['page_no']
            horizontal_scale_ratio, vertical_scale_ratio = get_scale_ratio(
                model_page_info, self.__docs[page_no]
            )
            layout_dets = model_page_info["layout_dets"]
            for layout_det in layout_dets:
                # poly is unpacked as 8 numbers. Values 0-1 and 4-5 are used as
                # opposite corners (presumably four (x, y) points; verify).
                x0, y0, _, _, x1, y1, _, _ = layout_det["poly"]
                layout_det["bbox"] = [
                    int(x0 / horizontal_scale_ratio),
                    int(y0 / vertical_scale_ratio),
                    int(x1 / horizontal_scale_ratio),
                    int(y1 / vertical_scale_ratio),
                ]
            # Remove spans whose width or height is 0. The filter is scoped to
            # this page: a removal list shared across pages would make
            # list.remove() raise ValueError for entries of earlier pages.
            # Slice assignment keeps the same list object in model_page_info.
            layout_dets[:] = [
                layout_det
                for layout_det in layout_dets
                if layout_det["bbox"][2] != layout_det["bbox"][0]
                and layout_det["bbox"][3] != layout_det["bbox"][1]
            ]

    def __init__(self, model_list: list, docs: "fitz.Document"):
        """Store the model output and the PDF document, then compute bboxes.

        Args:
            model_list: per-page model output. It is mutated in place by
                ``__fix_axis``: ``bbox`` keys are added and zero-width or
                zero-height detections are removed.
            docs: the opened PDF document, indexed by page number.
        """
        self.__model_list = model_list
        self.__docs = docs
        self.__fix_axis()

    def get_imgs(self, page_no: int):  # owner: @许瑞
        """Return the image blocks on page ``page_no`` (category id 3).

        Each block is a dict with keys ``bbox`` (whole image region),
        ``img_body_bbox`` and ``img_caption_bbox``. Caption pairing is not
        implemented yet, so ``img_caption_bbox`` is always ``None`` and
        ``bbox`` equals the body box. The caption key is always present so
        callers can rely on it.
        """
        image_blocks = []
        for layout_det in self.__model_list[page_no]["layout_dets"]:
            if layout_det["category_id"] == 3:  # 3: image
                body_bbox = layout_det["bbox"]
                image_blocks.append({
                    # TODO(@许瑞): union of body and caption boxes once
                    # captions are matched to images.
                    "bbox": body_bbox,
                    "img_body_bbox": body_bbox,
                    "img_caption_bbox": None,
                })
        return image_blocks

    def get_tables(self, page_no: int) -> list:
        """Return table blocks for page ``page_no``.

        Planned structure (same as image blocks): three boxes, for the caption,
        the table body and the table note.
        """
        # TODO(@许瑞): not implemented yet. Returns an empty list per the class
        # contract instead of None.
        return []

    def get_equations(self, page_no: int) -> tuple:  # owner: @凯文
        """Return ``(inline_equations, interline_equations)`` for page ``page_no``.

        Both are lists of ``{"bbox": ..., "content": <latex>}`` dicts, built from
        category 13 (inline equation) and 14 (interline equation) detections.
        Each detection has both a box and its LaTeX text.
        """
        inline_equations = []
        interline_equations = []
        for layout_det in self.__model_list[page_no]["layout_dets"]:
            category_id = layout_det["category_id"]
            if category_id == 13:
                inline_equations.append(
                    {"bbox": layout_det["bbox"], "content": layout_det["latex"]}
                )
            elif category_id == 14:
                interline_equations.append(
                    {"bbox": layout_det["bbox"], "content": layout_det["latex"]}
                )
        return inline_equations, interline_equations

    def get_discarded(self, page_no: int) -> list:
        """Return discarded regions for page ``page_no`` (in-house model, boxes only)."""
        # TODO(@凯文): not implemented yet. Returns an empty list per the class contract.
        return []

    def get_text_blocks(self, page_no: int) -> list:
        """Return text blocks for page ``page_no`` (in-house model, boxes only, no text)."""
        # TODO(@凯文): not implemented yet. Returns an empty list per the class contract.
        return []

    def get_title_blocks(self, page_no: int) -> list:
        """Return title blocks for page ``page_no`` (in-house model, boxes only, no text)."""
        # TODO(@凯文): not implemented yet. Returns an empty list per the class contract.
        return []

    def get_ocr_text(self, page_no: int) -> list:
        """Return OCR text spans (category id 15) on page ``page_no``.

        Each span is ``{"bbox": ..., "content": <text>}``. The original note
        attributes these to paddle, with both text and coordinates.
        """
        # Compare against the int 15, consistent with get_all_spans. The
        # previous string "15" never matched int category ids.
        return [
            {"bbox": layout_det['bbox'], "content": layout_det["text"]}
            for layout_det in self.__model_list[page_no]["layout_dets"]
            if layout_det["category_id"] == 15
        ]

    def get_all_spans(self, page_no: int) -> list:
        """Return every detection on the page that is stitched together as a span.

        Each span has ``bbox`` and ``type``. Equation spans also carry
        ``content`` with the LaTeX, and text spans carry ``content`` with the
        recognised text.

        Category ids handled:
            3: image
            5: table
            13: inline equation
            14: interline equation
            15: OCR-recognised text
        """
        all_spans = []
        model_page_info = self.__model_list[page_no]
        layout_dets = model_page_info["layout_dets"]
        allow_category_id_list = [3, 5, 13, 14, 15]
        for layout_det in layout_dets:
            category_id = layout_det["category_id"]
            if category_id in allow_category_id_list:
                span = {
                    "bbox": layout_det['bbox']
                }
                if category_id == 3:
                    span["type"] = ContentType.Image
                elif category_id == 5:
                    span["type"] = ContentType.Table
                elif category_id == 13:
                    span["content"] = layout_det["latex"]
                    span["type"] = ContentType.InlineEquation
                elif category_id == 14:
                    span["content"] = layout_det["latex"]
                    span["type"] = ContentType.InterlineEquation
                elif category_id == 15:
                    span["content"] = layout_det["text"]
                    span["type"] = ContentType.Text
                all_spans.append(span)
        return all_spans

    def get_page_size(self, page_no: int):
        """Return ``(width, height)`` of page ``page_no``, read from its ``rect``."""
        page = self.__docs[page_no]
        page_w = page.rect.width
        page_h = page.rect.height
        return page_w, page_h


if __name__ == '__main__':
    # Ad-hoc local debugging entry point: load one PDF and its model JSON from
    # a developer's disk and build a MagicModel from them. Paths are
    # machine-specific.
    reader = DiskReaderWriter(r"D:/project/20231108code-clean")

    pdf_bytes = reader.read(r"linshixuqiu\19983-00.pdf", AbsReaderWriter.MODE_BIN)
    raw_model_json = reader.read(r"linshixuqiu\19983-00_new.json", AbsReaderWriter.MODE_TXT)
    model_list = json.loads(raw_model_json)

    # Output location for extracted images. The writer is constructed here
    # but not used further in this script.
    output_root = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
    img_writer = DiskReaderWriter(join_path(output_root, "imgs"))

    pdf_docs = fitz.open("pdf", pdf_bytes)
    magic_model = MagicModel(model_list, pdf_docs)