# Unimernet.py
1
import argparse
2
import os
3
4
5
import re

import torch
6
7
8
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor


class MathDataset(Dataset):
    """Dataset over formula images given as file paths or in-memory images.

    Each entry of ``image_paths`` is either a ``str`` (opened with
    ``PIL.Image.open`` on access) or an object that is used as the image
    directly.
    """

    def __init__(self, image_paths, transform=None):
        """Store the entries and the optional per-item transform.

        Args:
            image_paths: Sequence of path strings and/or image objects.
            transform: Optional callable applied to every image on access.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of entries."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image stored at ``idx``."""
        entry = self.image_paths[idx]
        # Only string entries are treated as paths; anything else is assumed
        # to already be an image.
        raw_image = Image.open(entry) if isinstance(entry, str) else entry
        if self.transform is None:
            # Without a transform the raw image is the item; previously this
            # path fell through and returned None.
            return raw_image
        return self.transform(raw_image)


def latex_rm_whitespace(s: str) -> str:
    r"""Remove unnecessary whitespace from LaTeX code.

    Two passes are applied:

    1. Every ``\operatorname``/``\mathrm``/``\text``/``\mathbf`` group that is
       written with a space before its opening brace has all of its spaces
       removed (``\mathrm {a b}`` becomes ``\mathrm{ab}``).
    2. Whitespace is repeatedly removed between a non-letter and a non-letter,
       between a non-letter and a letter, and between a letter and a
       non-letter, until the string stops changing. Whitespace between two
       letters is kept, and the first two rules never start a match on the
       backslash of an escaped space (``\ ``).

    Args:
        s: LaTeX source string.

    Returns:
        The string with the whitespace described above removed.
    """
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = r"[a-zA-Z]"
    # Raw string: "\W" and "\d" are regex classes, not Python string escapes.
    noletter = r"[\W_^\d]"

    # Strip the spaces of each text-like group directly from its own match.
    s = re.sub(text_reg, lambda match: match.group(0).replace(" ", ""), s)

    # Compiled once; they are applied in this order on every iteration.
    squeezers = (
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter)),
        re.compile(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter)),
        re.compile(r"(%s)\s+?(%s)" % (letter, noletter)),
    )

    # Matches cannot overlap, so one pass may leave removable whitespace
    # behind; iterate until a full pass changes nothing.
    while True:
        news = s
        for pattern in squeezers:
            news = pattern.sub(r"\1\2", news)
        if news == s:
            return s
        s = news


class UnimernetModel:
    """Formula-recognition wrapper: crops detected boxes and decodes LaTeX.

    The model is built once in ``__init__``; ``predict`` handles one page and
    ``batch_predict`` handles several pages in a single recognition pass.
    """

    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        """Build the model from ``cfg_path`` using the files in ``weight_dir``.

        Args:
            weight_dir: Directory used as the model name, the tokenizer path
                and the location of ``pytorch_model.pth``.
            cfg_path: Config file path handed to ``Config``.
            _device_: Device the model and every input batch are moved to.
        """
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
        cfg.config.model.model_config.model_name = weight_dir
        cfg.config.model.tokenizer_config.path = weight_dir
        task = tasks.setup_task(cfg)
        self.model = task.build_model(cfg)
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose([vis_processor])

    @staticmethod
    def _crop_formulas(pil_img, boxes_xyxy, boxes_conf, boxes_cls):
        """Build one record and one cropped image per detection box.

        Returns:
            ``(formula_list, crops)``: two lists in box order. Each record has
            an empty ``"latex"`` field to be filled in after recognition.
        """
        formula_list = []
        crops = []
        for xyxy, conf, cla in zip(boxes_xyxy, boxes_conf, boxes_cls):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append(
                {
                    # NOTE(review): the +13 offset presumably maps detector
                    # class ids onto project category ids — confirm.
                    "category_id": 13 + int(cla.item()),
                    # Corners: (xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax).
                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    "score": round(float(conf.item()), 2),
                    "latex": "",
                }
            )
            crops.append(pil_img.crop((xmin, ymin, xmax, ymax)))
        return formula_list, crops

    def _recognize(self, mf_image_list, batch_size):
        """Run the model over the crops and return predictions in input order."""
        if not mf_image_list:
            # Nothing to recognise: skip building the dataset and loader.
            return []
        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        with torch.no_grad():
            for mf_img in dataloader:
                output = self.model.generate({"image": mf_img.to(self.device)})
                mfr_res.extend(output["pred_str"])
        return mfr_res

    def predict(self, mfd_res, image, batch_size=32):
        """Recognise every detected formula on a single page.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` (each supporting ``.cpu()``).
            image: Page image accepted by ``PIL.Image.fromarray``.
            batch_size: Number of crops per model call (default 32).

        Returns:
            One dict per box with ``category_id``, ``poly``, ``score`` and
            ``latex`` keys. A record keeps ``latex == ""`` if the model
            returns fewer strings than there are crops.
        """
        boxes = mfd_res.boxes
        # The page is converted once, not once per box.
        pil_img = Image.fromarray(image)
        formula_list, mf_image_list = self._crop_formulas(
            pil_img, boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()
        )
        mfr_res = self._recognize(mf_image_list, batch_size)
        for res, latex in zip(formula_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list

    def batch_predict(
        self, images_mfd_res: list, images: list, batch_size: int = 64
    ) -> list:
        """Recognise the formulas of several pages in one recognition pass.

        Args:
            images_mfd_res: Per-page detection results (see ``predict``).
            images: Page images, index-aligned with ``images_mfd_res``.
            batch_size: Number of crops per model call.

        Returns:
            One list of formula records per page, in page order.

        Raises:
            IndexError: If ``images`` is shorter than ``images_mfd_res``.
        """
        images_formula_list = []
        mf_image_list = []
        # Flat view over the same record dicts, in crop order, so predictions
        # can be written back with a single zip.
        backfill_list = []
        for image_index, mfd_res in enumerate(images_mfd_res):
            pil_img = Image.fromarray(images[image_index])
            boxes = mfd_res.boxes
            formula_list, crops = self._crop_formulas(
                pil_img, boxes.xyxy, boxes.conf, boxes.cls
            )
            mf_image_list.extend(crops)
            images_formula_list.append(formula_list)
            backfill_list += formula_list

        mfr_res = self._recognize(mf_image_list, batch_size)
        for res, latex in zip(backfill_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return images_formula_list