paddle_table_cls.py 5.65 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os

from PIL import Image
import cv2
import numpy as np
import onnxruntime
from loguru import logger
from tqdm import tqdm

from mineru.backend.pipeline.model_list import AtomicModel
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path


class PaddleTableClsModel:
    def __init__(self):
        self.sess = onnxruntime.InferenceSession(
            os.path.join(auto_download_and_get_model_root_path(ModelPath.paddle_table_cls), ModelPath.paddle_table_cls)
        )
        self.less_length = 256
        self.cw, self.ch = 224, 224
        self.std = [0.229, 0.224, 0.225]
        self.scale = 0.00392156862745098
        self.mean = [0.485, 0.456, 0.406]
        self.labels = [AtomicModel.WiredTable, AtomicModel.WirelessTable]

    def preprocess(self, input_img):
        # 放大图片,使其最短边长为256
        h, w = input_img.shape[:2]
        scale = 256 / min(h, w)
        h_resize = round(h * scale)
        w_resize = round(w * scale)
        img = cv2.resize(input_img, (w_resize, h_resize), interpolation=1)
        # 调整为224*224的正方形
        h, w = img.shape[:2]
        cw, ch = 224, 224
        x1 = max(0, (w - cw) // 2)
        y1 = max(0, (h - ch) // 2)
        x2 = min(w, x1 + cw)
        y2 = min(h, y1 + ch)
        if w < cw or h < ch:
            raise ValueError(
                f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
            )
        img = img[y1:y2, x1:x2, ...]
        # 正则化
        split_im = list(cv2.split(img))
        std = [0.229, 0.224, 0.225]
        scale = 0.00392156862745098
        mean = [0.485, 0.456, 0.406]
        alpha = [scale / std[i] for i in range(len(std))]
        beta = [-mean[i] / std[i] for i in range(len(std))]
        for c in range(img.shape[2]):
            split_im[c] = split_im[c].astype(np.float32)
            split_im[c] *= alpha[c]
            split_im[c] += beta[c]
        img = cv2.merge(split_im)
        # 5. 转换为 CHW 格式
        img = img.transpose((2, 0, 1))
        imgs = [img]
        x = np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)
        return x

    def predict(self, input_img):
        if isinstance(input_img, Image.Image):
            np_img = np.asarray(input_img)
        elif isinstance(input_img, np.ndarray):
            np_img = input_img
        else:
            raise ValueError("Input must be a pillow object or a numpy array.")
        x = self.preprocess(np_img)
        result = self.sess.run(None, {"x": x})
        idx = np.argmax(result)
        conf = float(np.max(result))
        return self.labels[idx], conf

    def list_2_batch(self, img_list, batch_size=16):
        """
        将任意长度的列表按照指定的batch size分成多个batch

        Args:
            img_list: 输入的列表
            batch_size: 每个batch的大小,默认为16

        Returns:
            一个包含多个batch的列表,每个batch都是原列表的一个子列表
        """
        batches = []
        for i in range(0, len(img_list), batch_size):
            batch = img_list[i : min(i + batch_size, len(img_list))]
            batches.append(batch)
        return batches

    def batch_preprocess(self, imgs):
        res_imgs = []
        for img in imgs:
            img = np.asarray(img)
            # 放大图片,使其最短边长为256
            h, w = img.shape[:2]
            scale = 256 / min(h, w)
            h_resize = round(h * scale)
            w_resize = round(w * scale)
            img = cv2.resize(img, (w_resize, h_resize), interpolation=1)
            # 调整为224*224的正方形
            h, w = img.shape[:2]
            cw, ch = 224, 224
            x1 = max(0, (w - cw) // 2)
            y1 = max(0, (h - ch) // 2)
            x2 = min(w, x1 + cw)
            y2 = min(h, y1 + ch)
            if w < cw or h < ch:
                raise ValueError(
                    f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
                )
            img = img[y1:y2, x1:x2, ...]
            # 正则化
            split_im = list(cv2.split(img))
            std = [0.229, 0.224, 0.225]
            scale = 0.00392156862745098
            mean = [0.485, 0.456, 0.406]
            alpha = [scale / std[i] for i in range(len(std))]
            beta = [-mean[i] / std[i] for i in range(len(std))]
            for c in range(img.shape[2]):
                split_im[c] = split_im[c].astype(np.float32)
                split_im[c] *= alpha[c]
                split_im[c] += beta[c]
            img = cv2.merge(split_im)
            # 5. 转换为 CHW 格式
            img = img.transpose((2, 0, 1))
            res_imgs.append(img)
        x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
        return x
    def batch_predict(self, img_info_list, batch_size=16):
        imgs = [item["wired_table_img"] for item in img_info_list]
        imgs = self.list_2_batch(imgs, batch_size=batch_size)
        label_res = []
        with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
            for img_batch in imgs:
                x = self.batch_preprocess(img_batch)
                result = self.sess.run(None, {"x": x})
                for img_res in result[0]:
                    idx = np.argmax(img_res)
                    conf = float(np.max(img_res))
                    label_res.append((self.labels[idx],conf))
                pbar.update(len(img_batch))
            for img_info, (label, conf) in zip(img_info_list, label_res):
                img_info['table_res']["cls_label"] = label
                img_info['table_res']["cls_score"] = round(conf, 3)