import argparse
import os
import re

import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor


class MathDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Entries may be file paths or already-loaded image objects
        if isinstance(self.image_paths[idx], str):
            raw_image = Image.open(self.image_paths[idx])
        else:
            raw_image = self.image_paths[idx]
        if self.transform:
            return self.transform(raw_image)
        return raw_image


def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = "[a-zA-Z]"
    noletter = r"[\W_^\d]"
    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        # Iteratively strip whitespace around non-letter boundaries until stable
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
        if news == s:
            break
    return s
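# Illustrative example (assumption, not taken from the original source):
#   latex_rm_whitespace(r"\mathrm {a} + b")  ->  r"\mathrm{a}+b"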


class UnimernetModel(object):
    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        # Build the UniMERNet model from the config, pointing it at local weights
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()

        # Image preprocessing pipeline used for formula recognition
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose(
            [
                vis_processor,
            ]
        )

    def predict(self, mfd_res, image):
        formula_list = []
        mf_image_list = []
        # Crop each detected formula region out of the page image
        for xyxy, conf, cla in zip(
            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": 13 + int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 2),
                "latex": "",
            }
            formula_list.append(new_item)
            pil_img = Image.fromarray(image)
            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
            mf_image_list.append(bbox_img)

        # Run formula recognition over the crops in batches
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["pred_str"])

        # Attach the cleaned LaTeX back to the detection results
        for res, latex in zip(formula_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list

    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
        images_formula_list = []
        mf_image_list = []
        backfill_list = []
        image_info = []  # Store (area, original_index, image) tuples

        # Collect formula crops from every page, remembering their original indices
        for image_index in range(len(images_mfd_res)):
            mfd_res = images_mfd_res[image_index]
            np_array_image = images[image_index]
            formula_list = []

            for idx, (xyxy, conf, cla) in enumerate(zip(
                    mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
            )):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    "category_id": 13 + int(cla.item()),
                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    "score": round(float(conf.item()), 2),
                    "latex": "",
                }
                formula_list.append(new_item)
                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                area = (xmax - xmin) * (ymax - ymin)

                curr_idx = len(mf_image_list)
                image_info.append((area, curr_idx, bbox_img))
                mf_image_list.append(bbox_img)

            images_formula_list.append(formula_list)
            backfill_list += formula_list

        # Stable sort by crop area so similarly sized images land in the same batch
        image_info.sort(key=lambda x: x[0])
        sorted_indices = [x[1] for x in image_info]
        sorted_images = [x[2] for x in image_info]

        # Map positions in the sorted order back to the original indices
        index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}

        # Create dataset with sorted images
        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

        # Process batches and store results
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["pred_str"])

        # Restore original order
        unsorted_results = [""] * len(mfr_res)
        for new_idx, latex in enumerate(mfr_res):
            original_idx = index_mapping[new_idx]
            unsorted_results[original_idx] = latex_rm_whitespace(latex)

        # Fill results back into the per-page dictionaries
        for res, latex in zip(backfill_list, unsorted_results):
            res["latex"] = latex

        return images_formula_list
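

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this class might be wired up. The paths below and the
# upstream formula-detection (MFD) step are assumptions: `mfd_res` must expose
# ultralytics-style `boxes.xyxy`, `boxes.conf`, and `boxes.cls` tensors, which is
# what predict()/batch_predict() read above.
#
# if __name__ == "__main__":
#     import numpy as np
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model = UnimernetModel(
#         weight_dir="models/unimernet",              # hypothetical weight directory
#         cfg_path="configs/unimernet_config.yaml",   # hypothetical config file
#         _device_=device,
#     )
#     page = np.asarray(Image.open("page.png").convert("RGB"))  # hypothetical page image
#     mfd_res = ...  # detection result for `page` from an MFD model (not shown here)
#     for item in model.predict(mfd_res, page):
#         print(item["score"], item["latex"])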