# Unimernet.py (4.8 KB)
# Newer / Older
1
import torch
2
from torch.utils.data import DataLoader, Dataset
3
4
5
6
7
8
9
10
11
12


class MathDataset(Dataset):
    """Map-style dataset over an in-memory list of items, with an optional transform.

    Despite the parameter name, ``image_paths`` holds whatever items the caller
    wants fed to ``transform``; within this module the callers pass image crops
    (array slices), not file paths.

    Args:
        image_paths: Sequence of items to serve, indexed by position.
        transform: Optional callable applied to each item on access. When it
            is ``None`` the item is returned unchanged.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        raw_image = self.image_paths[idx]
        # Previously a missing transform made this return None implicitly,
        # silently dropping the sample; return the raw item instead.
        if self.transform is not None:
            return self.transform(raw_image)
        return raw_image
18
19
20


class UnimernetModel(object):
    """Transcribe detected formula regions of page images to LaTeX.

    Wraps the ``UnimernetModel`` loaded from ``.unimernet_hf``. The class:

    - crops every box reported by a formula detector,
    - runs the recognizer on the crops in batches,
    - fills the recognized strings back into per-box result dicts.
    """

    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        """Load the recognizer weights and place the model on ``_device_``.

        Args:
            weight_dir: Directory handed to ``from_pretrained``.
            cfg_path: Accepted for interface compatibility; not used here.
            _device_: Torch device string. Strings starting with ``"mps"``
                load the model with eager attention. Any device not starting
                with ``"cpu"`` runs the model in float16.
        """
        from .unimernet_hf import UnimernetModel as _HfUnimernetModel

        if _device_.startswith("mps"):
            # Eager attention is forced on MPS; presumably the default attention
            # implementation is unsupported or unreliable there -- TODO confirm.
            self.model = _HfUnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
        else:
            self.model = _HfUnimernetModel.from_pretrained(weight_dir)
        self.device = _device_
        self.model.to(_device_)
        if not _device_.startswith("cpu"):
            self.model = self.model.to(dtype=torch.float16)
        self.model.eval()

    @staticmethod
    def _boxes_to_formulas(mfd_res, image):
        """Convert one detection result into formula dicts, image crops and box areas.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors.
            image: Page image indexed as ``image[y, x]``.

        Returns:
            ``(formula_list, crops, areas)``: three parallel lists, one entry
            per box, in detection order. Each formula dict has an empty
            ``"latex"`` field to be filled later.
        """
        formula_list, crops, areas = [], [], []
        boxes = mfd_res.boxes
        # Copy each tensor to the host once; calling ``.item()`` on device
        # tensors would otherwise synchronize on every scalar.
        for xyxy, conf, cla in zip(boxes.xyxy.cpu(), boxes.conf.cpu(), boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            formula_list.append({
                # Detector class id shifted by 13 into this pipeline's category space.
                "category_id": 13 + int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 2),
                "latex": "",
            })
            crops.append(image[ymin:ymax, xmin:xmax])
            areas.append((xmax - xmin) * (ymax - ymin))
        return formula_list, crops, areas

    def _recognize(self, crops, batch_size):
        """Run the recognizer over ``crops`` in batches of ``batch_size``.

        Returns:
            The model's ``"fixed_str"`` outputs concatenated in input order
            (one per crop). Returns an empty list for empty input, without
            touching the model.
        """
        if not crops:
            return []
        dataset = MathDataset(crops, transform=self.model.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        latex_list = []
        with torch.no_grad():
            for mf_img in dataloader:
                mf_img = mf_img.to(self.device, dtype=self.model.dtype)
                output = self.model.generate({"image": mf_img})
                latex_list.extend(output["fixed_str"])
        return latex_list

    def predict(self, mfd_res, image):
        """Recognize every detected formula on a single page.

        Args:
            mfd_res: Detection result exposing ``boxes.xyxy``, ``boxes.conf``
                and ``boxes.cls`` tensors.
            image: Page image indexed as ``image[y, x]``.

        Returns:
            A list of dicts, one per box, in detection order, with keys:

            - ``category_id``: detector class + 13.
            - ``poly``: 4 corner points, flattened.
            - ``score``: confidence rounded to 2 decimals.
            - ``latex``: the recognized string.
        """
        formula_list, crops, _ = self._boxes_to_formulas(mfd_res, image)
        for res, latex in zip(formula_list, self._recognize(crops, batch_size=32)):
            res["latex"] = latex
        return formula_list

    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
        """Recognize formulas across many pages, batching crops from all pages together.

        Args:
            images_mfd_res: One detection result per page (see ``predict``).
            images: Page images, indexed in parallel with ``images_mfd_res``.
            batch_size: Number of crops per recognizer batch.

        Returns:
            One list of formula dicts (same format as ``predict``) per entry
            of ``images_mfd_res``, in the same order.
        """
        images_formula_list = []
        backfill_list = []  # every formula dict across all pages, in detection order
        crops, areas = [], []
        for page_idx, mfd_res in enumerate(images_mfd_res):
            formula_list, page_crops, page_areas = self._boxes_to_formulas(mfd_res, images[page_idx])
            images_formula_list.append(formula_list)
            backfill_list.extend(formula_list)
            crops.extend(page_crops)
            areas.extend(page_areas)

        # Recognize crops ordered by area so each batch holds similarly sized
        # images (presumably to limit padding). ``sorted`` is stable, so ties
        # keep detection order.
        order = sorted(range(len(crops)), key=areas.__getitem__)
        latex_sorted = self._recognize([crops[i] for i in order], batch_size)

        # Scatter results back to their original detection positions.
        for crop_idx, latex in zip(order, latex_sorted):
            backfill_list[crop_idx]["latex"] = latex
        return images_formula_list