# Unimernet.py
1
import argparse
2
import os
3
4
5
import re

import torch
6
7
8
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor


class MathDataset(Dataset):
    """Map-style dataset over formula images given as file paths or images.

    Each item is passed through ``transform`` when one is supplied;
    otherwise the (possibly freshly opened) image itself is returned.
    """

    def __init__(self, image_paths, transform=None):
        # Entries may be filesystem paths (str) or already-loaded images.
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        item = self.image_paths[idx]
        # Paths are opened on access; anything else is used as-is and is
        # presumably already a PIL image (that is what this file passes in).
        raw_image = Image.open(item) if isinstance(item, str) else item
        if self.transform is not None:
            return self.transform(raw_image)
        # Without a transform, hand back the image instead of falling through
        # and implicitly returning None.
        return raw_image


def latex_rm_whitespace(s: str):
    r"""Remove unnecessary whitespace from LaTeX code.

    Two steps are applied:

    1. For ``\operatorname``, ``\mathrm``, ``\text`` and ``\mathbf`` written
       as ``\cmd {...}``, every space inside the matched command (including
       the one before the brace) is dropped, e.g. ``\mathrm {d x}`` becomes
       ``\mathrm{dx}``.
    2. Whitespace between two characters is removed when at least one side
       is a non-letter (punctuation, digit, backslash, brace, ...), while a
       space between two ASCII letters is kept so tokens such as
       ``\alpha x`` stay separated.  The ``(?!\\ )`` lookahead stops an
       escaped space ``\ `` from acting as the left-hand character.  These
       substitutions are repeated until the string no longer changes.

    Args:
        s: LaTeX source, e.g. the space-separated token string produced by
            the recognizer.

    Returns:
        The compacted LaTeX string.
    """
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = r"[a-zA-Z]"
    # Raw string: "\W" / "\d" in a plain literal are invalid escape
    # sequences (SyntaxWarning since Python 3.12).
    noletter = r"[\W_^\d]"

    # re.sub visits exactly the matches findall would, so each replacement can
    # be derived from its own match instead of popping a pre-built list from
    # the front (which was O(n^2)).  Group 1 wraps the whole pattern, so
    # group(0) is the same text the old findall-based list held.
    s = re.sub(text_reg, lambda match: match.group(0).replace(" ", ""), s)

    patterns = (
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter)),
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter)),
        re.compile(r"(%s)\s+?(%s)" % (letter, noletter)),
    )
    # Iterate to a fixed point: one removal can create a new adjacent pair.
    while True:
        news = s
        for pattern in patterns:
            news = pattern.sub(r"\1\2", news)
        if news == s:
            return s
        s = news


class UnimernetModel:
    """Math-formula recognizer built on a UniMERNet checkpoint.

    Takes the boxes produced by a formula-detection step, crops each region
    from the page image and asks the model for its LaTeX source.
    """

    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        """Load the model and its evaluation image processor.

        Args:
            weight_dir: Directory holding ``pytorch_model.pth``; it is also
                used as the model-config name and the tokenizer path.
            cfg_path: Path of the UniMERNet config file read by ``Config``.
            _device_: Device the model is moved to (e.g. ``"cpu"``).
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()  # inference only

        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose([vis_processor])

    @staticmethod
    def _collect_formulas(mfd_res, pil_img):
        """Turn one image's detections into result dicts plus matching crops.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors (assumed one row per box -- TODO
                confirm against the detector).
            pil_img: The page image the boxes refer to.

        Returns:
            ``(formula_list, crops)``: one dict (with an empty ``"latex"`` to
            be filled in later) and one cropped image per box, in detection
            order.
        """
        boxes = mfd_res.boxes
        formula_list = []
        crops = []
        # Move the tensors to the CPU once instead of syncing per ``.item()``.
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append(
                {
                    # Detector class index shifted by 13.
                    "category_id": 13 + int(cla.item()),
                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    "score": round(float(conf.item()), 2),
                    "latex": "",
                }
            )
            crops.append(pil_img.crop((xmin, ymin, xmax, ymax)))
        return formula_list, crops

    def _recognize(self, crops, batch_size):
        """Run the model over ``crops``; return cleaned LaTeX, one per crop, in order."""
        if not crops:
            return []
        dataset = MathDataset(crops, transform=self.mfr_transform)
        # DataLoader does not shuffle by default, so outputs follow input order.
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        latex_list = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            latex_list.extend(output["pred_str"])
        return [latex_rm_whitespace(latex) for latex in latex_list]

    def predict(self, mfd_res, image, batch_size: int = 32):
        """Recognize every formula detected on a single image.

        Args:
            mfd_res: Detection result for ``image`` (see ``_collect_formulas``).
            image: Page image as an array accepted by ``PIL.Image.fromarray``.
            batch_size: Crops per model forward pass.

        Returns:
            A list of dicts with ``category_id``, ``poly`` (flattened corners
            ``[xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]``), ``score``
            and ``latex``, in detection order.
        """
        # Convert once per page rather than once per detected box.
        pil_img = Image.fromarray(image)
        formula_list, crops = self._collect_formulas(mfd_res, pil_img)
        for item, latex in zip(formula_list, self._recognize(crops, batch_size)):
            item["latex"] = latex
        return formula_list

    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
        """Recognize formulas for several images with shared batching.

        Crops from all images are pooled and fed to the model in ascending
        box-area order, so each batch holds similarly sized crops; results
        are then written back to their original positions.

        Args:
            images_mfd_res: Detection results, one per image.
            images: Page images, indexed in parallel with ``images_mfd_res``.
            batch_size: Crops per model forward pass.

        Returns:
            One list of formula dicts (same shape as ``predict`` returns) per
            image.
        """
        images_formula_list = []
        all_formulas = []  # every formula dict across images, aligned with all_crops
        all_crops = []
        for image_index, mfd_res in enumerate(images_mfd_res):
            pil_img = Image.fromarray(images[image_index])
            formula_list, crops = self._collect_formulas(mfd_res, pil_img)
            images_formula_list.append(formula_list)
            all_formulas.extend(formula_list)
            all_crops.extend(crops)

        # Box area from the poly: (xmax - xmin) * (ymax - ymin).
        areas = [
            (item["poly"][2] - item["poly"][0]) * (item["poly"][5] - item["poly"][1])
            for item in all_formulas
        ]
        # sorted() is stable, so equal areas keep their detection order.
        order = sorted(range(len(all_crops)), key=areas.__getitem__)
        sorted_latex = self._recognize([all_crops[i] for i in order], batch_size)

        # Scatter results back to each crop's original position; the dicts are
        # shared with images_formula_list, so this fills the per-image lists.
        for crop_index, latex in zip(order, sorted_latex):
            all_formulas[crop_index]["latex"] = latex
        return images_formula_list