# Unimernet.py
1
import argparse
2
import os
3
4
5
import re

import torch
6
7
import unimernet.tasks as tasks
from torch.utils.data import DataLoader, Dataset
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor


class MathDataset(Dataset):
    """Map-style dataset over a sequence of formula images.

    Args:
        image_paths: Indexable sequence of items to serve. Despite the name,
            the callers in this file pass already-cropped image arrays
            rather than file paths.
        transform: Optional callable applied to each item before it is
            returned; when ``None`` the item is returned unchanged.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the item at *idx*, passed through ``transform`` if set.

        ``DataLoader`` calls this for every index; without it the base
        ``Dataset`` implementation raises ``NotImplementedError``.
        """
        image = self.image_paths[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image


def latex_rm_whitespace(s: str):
    r"""Remove unnecessary whitespace from LaTeX code.

    Two passes are applied:

    1. Inside ``\operatorname``, ``\mathrm``, ``\text`` and ``\mathbf``
       groups written with a space before the opening brace, every space
       is dropped.
    2. Whitespace between two non-letters, between a non-letter and a
       letter, and between a letter and a non-letter is removed repeatedly
       until the string stops changing. Whitespace between two letters is
       left untouched.

    Args:
        s: LaTeX source string.

    Returns:
        The string with the whitespace described above removed.
    """
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = r"[a-zA-Z]"
    # Raw string: "\W" and "\d" are regex classes, not string escapes.
    noletter = r"[\W_^\d]"

    # Strip spaces inside each text-like group in a single substitution.
    s = re.sub(text_reg, lambda match: match.group(0).replace(" ", ""), s)

    # Compile once; the loop below may run several times.
    # The "(?!\\ )" lookahead keeps a control space (backslash followed by a
    # space) from being used as the left-hand character of a match.
    noletter_noletter = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter))
    noletter_letter = re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter))
    letter_noletter = re.compile(r"(%s)\s+?(%s)" % (letter, noletter))

    while True:
        news = noletter_noletter.sub(r"\1\2", s)
        news = noletter_letter.sub(r"\1\2", news)
        news = letter_noletter.sub(r"\1\2", news)
        if news == s:
            # Fixed point reached: another pass would change nothing.
            return s
        s = news


class UnimernetModel(object):
    """Recognise formulas in detected regions with a UniMERNet model.

    The constructor builds the model from a config file and a weight
    directory. ``predict`` handles the detections of one image and
    ``batch_predict`` handles several images in a single pass.
    """

    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        """Build the model and its image pre-processing transform.

        Args:
            weight_dir: Directory containing ``pytorch_model.pth``; it is
                also written into the config as the model name and the
                tokenizer path.
            cfg_path: Path of the config file handed to ``Config``.
            _device_: Device the model and the input batches are moved to.
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose([vis_processor])

    @staticmethod
    def _iter_formula_crops(detections, image):
        """Yield ``(item, crop, area)`` for each ``(xyxy, conf, cls)`` triple.

        ``item`` is the result dict with an empty ``"latex"`` field, ``crop``
        is ``image[ymin:ymax, xmin:xmax]`` and ``area`` is the box area in
        pixels after truncating the coordinates to ``int``.
        """
        for xyxy, conf, cla in detections:
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            item = {
                # NOTE(review): the offset 13 presumably maps detector class
                # ids into the caller's category id space -- confirm.
                "category_id": 13 + int(cla.item()),
                # Corners in order: top-left, top-right, bottom-right,
                # bottom-left.
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 2),
                "latex": "",
            }
            yield item, image[ymin:ymax, xmin:xmax], (xmax - xmin) * (ymax - ymin)

    def _recognize(self, crops, batch_size):
        """Run the model over *crops* and return the predicted strings in order."""
        if not crops:
            return []
        dataset = MathDataset(crops, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        predictions = []
        for batch in dataloader:
            batch = batch.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": batch})
            predictions.extend(output["pred_str"])
        return predictions

    def predict(self, mfd_res, image, batch_size=32):
        """Recognise every formula detected in a single image.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors.
            image: Image indexed as ``image[ymin:ymax, xmin:xmax]``
                (presumably a NumPy array -- confirm against the caller).
            batch_size: Number of crops sent to the model per step.

        Returns:
            A list with one dict per detection holding ``category_id``,
            ``poly``, ``score`` and the recognised ``latex`` string.
        """
        boxes = zip(
            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
        )
        formula_list = []
        crops = []
        for item, crop, _area in self._iter_formula_crops(boxes, image):
            formula_list.append(item)
            crops.append(crop)

        for res, latex in zip(formula_list, self._recognize(crops, batch_size)):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list

    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
        """Recognise the formulas detected in several images at once.

        All crops are pooled and ordered by box area before batching, then
        the results are written back in their original order.

        Args:
            images_mfd_res: One detection result per image, each exposing
                ``boxes.xyxy``, ``boxes.conf`` and ``boxes.cls``.
            images: The images, index-aligned with ``images_mfd_res``.
            batch_size: Number of crops sent to the model per step.

        Returns:
            A list with one entry per image; each entry is that image's
            list of result dicts, in detection order.
        """
        images_formula_list = []
        backfill_list = []  # every result dict, in detection order across images
        crops = []
        areas = []

        for image_index, mfd_res in enumerate(images_mfd_res):
            np_array_image = images[image_index]
            boxes = zip(mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls)
            formula_list = []
            for item, crop, area in self._iter_formula_crops(boxes, np_array_image):
                formula_list.append(item)
                crops.append(crop)
                areas.append(area)
            images_formula_list.append(formula_list)
            backfill_list += formula_list

        # Stable sort by area; order[k] is the original position of the k-th
        # smallest crop. Presumably this groups similarly sized crops into
        # the same batch -- the source does not state the reason.
        order = sorted(range(len(areas)), key=areas.__getitem__)
        sorted_crops = [crops[i] for i in order]

        mfr_res = self._recognize(sorted_crops, batch_size)

        # The dicts in backfill_list are the same objects held by
        # images_formula_list, so filling them here updates the return value.
        for original_idx, latex in zip(order, mfr_res):
            backfill_list[original_idx]["latex"] = latex_rm_whitespace(latex)

        return images_formula_list