RapidTable.py 5.83 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import html
import os
import time
from pathlib import Path
from typing import List

import cv2
import numpy as np
from loguru import logger
from rapid_table import ModelType, RapidTable, RapidTableInput
from rapid_table.utils import RapidTableOutput
from tqdm import tqdm

from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path


def escape_html(input_string):
    """Return *input_string* with HTML-special characters turned into entities.

    Delegates to :func:`html.escape` with its default ``quote=True``, so
    ``&``, ``<``, ``>``, ``"`` and ``'`` are all replaced.
    """
    escaped = html.escape(input_string, quote=True)
    return escaped


class CustomRapidTable(RapidTable):
    """``RapidTable`` with a tqdm progress bar and stdlib INFO logging disabled.

    ``__call__`` fills only ``pred_htmls`` and ``elapse`` on the returned
    ``RapidTableOutput``; its other fields keep whatever values
    ``RapidTableOutput()`` starts with.
    """

    def __init__(self, cfg: RapidTableInput):
        import logging

        # Process-wide switch (not configured through environment variables):
        # it suppresses every stdlib ``logging`` record at INFO level and below,
        # for all loggers in the process -- presumably to quiet rapid_table's
        # per-call INFO output.
        logging.disable(logging.INFO)
        super().__init__(cfg)

    def __call__(self, img_contents, ocr_results=None, batch_size=1):
        """Predict table HTML for one image or a list of images.

        Args:
            img_contents: a single image, or a list of images, in any form
                accepted by ``RapidTable._load_imgs``. A non-list value is
                wrapped in a one-element list.
            ocr_results: OCR results aligned with ``img_contents``; sliced per
                batch and forwarded to ``get_ocr_results``. May be None.
            batch_size: number of images processed per structure-model step.

        Returns:
            RapidTableOutput whose ``pred_htmls`` is extended with the matcher
            output of every batch, and whose ``elapse`` is the mean wall-clock
            seconds per image (0.0 when ``img_contents`` is an empty list).
        """
        if not isinstance(img_contents, list):
            img_contents = [img_contents]

        started = time.perf_counter()
        results = RapidTableOutput()
        total_nums = len(img_contents)

        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
            for start_i in range(0, total_nums, batch_size):
                end_i = min(total_nums, start_i + batch_size)
                imgs = self._load_imgs(img_contents[start_i:end_i])

                # Structure tokens + cell boxes are only consumed by the
                # matcher here; they are not copied into ``results``.
                pred_structures, cell_bboxes = self.table_structure(imgs)
                dt_boxes, rec_res = self.get_ocr_results(
                    imgs, start_i, end_i, ocr_results
                )
                pred_htmls = self.table_matcher(
                    pred_structures, cell_bboxes, dt_boxes, rec_res
                )

                results.pred_htmls.extend(pred_htmls)
                pbar.update(end_i - start_i)

        elapsed = time.perf_counter() - started
        # Average per image; an empty input list must not divide by zero.
        results.elapse = elapsed / total_nums if total_nums else 0.0
        return results


class RapidTableModel:
    """Wireless-table recognizer built on the SLANet+ model via ``CustomRapidTable``.

    Cell text comes either from OCR results supplied by the caller or, in
    :meth:`predict`, from running ``ocr_engine`` on the table image.
    """

    def __init__(self, ocr_engine):
        """Load the SLANet+ table-structure model and keep ``ocr_engine``.

        Args:
            ocr_engine: object exposing ``ocr(bgr_image)``; only used by
                :meth:`predict` when no OCR result is passed in.
        """
        slanet_plus_model_path = os.path.join(
            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
            ModelPath.slanet_plus,
        )
        input_args = RapidTableInput(
            model_type=ModelType.SLANETPLUS,
            model_dir_or_path=slanet_plus_model_path,
            # OCR is supplied by this wrapper / its caller, not by rapid_table.
            use_ocr=False,
        )
        self.table_model = CustomRapidTable(input_args)
        self.ocr_engine = ocr_engine

    @staticmethod
    def _convert_ocr_result(raw_ocr_result):
        """Split raw OCR items into the ``(boxes, texts, scores)`` triple used by rapid_table.

        Recognised item layouts are ``[box, text, score]`` and
        ``[box, (text, score)]``; any other layout is skipped. Texts are
        HTML-escaped so they are safe to embed in the predicted table HTML.
        ``None`` is treated as "no items".
        """
        boxes, texts, scores = [], [], []
        for item in raw_ocr_result or ():
            if len(item) == 3:
                box, text, score = item[0], item[1], item[2]
            elif len(item) == 2 and isinstance(item[1], tuple):
                box, text, score = item[0], item[1][0], item[1][1]
            else:
                continue
            boxes.append(box)
            texts.append(escape_html(text))
            scores.append(score)
        return boxes, texts, scores

    def predict(self, image, ocr_result=None):
        """Recognize a single table image.

        Args:
            image: table image convertible with ``np.asarray``; it is treated
                as RGB (it is converted RGB -> BGR before running OCR).
            ocr_result: optional OCR result already in the format passed to
                ``CustomRapidTable`` as ``ocr_results`` (assumed to be a
                one-element list of ``(boxes, texts, scores)`` -- confirm with
                callers). When falsy, ``self.ocr_engine`` is run instead.

        Returns:
            ``(html_code, table_cell_bboxes, logic_points, elapse)`` read from
            the table model's output, or ``(None, None, None, None)`` when OCR
            finds no text or the table model raises (the error is logged).
        """
        if not ocr_result:
            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
            # The OCR engine may yield None (or nothing) when no text is
            # detected; iterating None used to raise TypeError here.
            if not raw_ocr_result:
                return None, None, None, None
            ocr_result = [self._convert_ocr_result(raw_ocr_result)]

        try:
            table_results = self.table_model(
                img_contents=np.asarray(image), ocr_results=ocr_result
            )
            return (
                table_results.pred_htmls,
                table_results.cell_bboxes,
                table_results.logic_points,
                table_results.elapse,
            )
        except Exception as e:
            # Best effort: a failed table is reported but does not abort the caller.
            logger.exception(e)
            return None, None, None, None

    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
        """Recognize many table images, writing results into the dicts in place.

        Only entries with a truthy ``"ocr_result"`` are processed. Each such
        entry must also carry ``"table_img"`` and a ``"table_res"`` dict; its
        ``table_res["html"]`` is set when the predicted HTML is non-empty.
        Returns None.

        Args:
            table_res_list: table descriptors as described above.
            batch_size: images per structure-model batch.
        """
        with_ocr = [res for res in table_res_list if res.get("ocr_result")]
        if not with_ocr:
            return

        img_contents = [res["table_img"] for res in with_ocr]
        ocr_results = [self._convert_ocr_result(res["ocr_result"]) for res in with_ocr]
        table_results = self.table_model(
            img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size
        )

        for table_res, html_code in zip(with_ocr, table_results.pred_htmls):
            if html_code:
                table_res["table_res"]["html"] = html_code

if __name__ == '__main__':
    # Ad-hoc smoke test: OCR + table recognition on one image, printing the
    # predicted HTML. An image path may be given as the first CLI argument.
    import sys

    default_img_path = r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg"
    img_path = Path(sys.argv[1] if len(sys.argv) > 1 else default_img_path)
    image = cv2.imread(str(img_path))
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise SystemExit(f"Could not read image: {img_path}")
    # cv2.imread yields BGR, but RapidTableModel.predict treats its input as
    # RGB (it converts RGB -> BGR itself before OCR), so convert first.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    ocr_engine = PytorchPaddleOCR(
        det_db_box_thresh=0.5,
        det_db_unclip_ratio=1.6,
        enable_merge_det_boxes=False,
    )
    table_model = RapidTableModel(ocr_engine)
    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
    print(html_code)