Commit 59fc80d4 authored by myhloli
Browse files

perf(mfr): improve Math Formula Recognition by sorting images by area

- Sort detected images by area before processing to enhance MFR accuracy
- Implement stable sorting to maintain the original order of images with equal area
parent 6bfc1711
...@@ -100,20 +100,61 @@ class UnimernetModel(object): ...@@ -100,20 +100,61 @@ class UnimernetModel(object):
res["latex"] = latex_rm_whitespace(latex) res["latex"] = latex_rm_whitespace(latex)
return formula_list return formula_list
def batch_predict( # def batch_predict(
self, images_mfd_res: list, images: list, batch_size: int = 64 # self, images_mfd_res: list, images: list, batch_size: int = 64
) -> list: # ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
images_formula_list = [] images_formula_list = []
mf_image_list = [] mf_image_list = []
backfill_list = [] backfill_list = []
image_info = [] # Store (area, original_index, image) tuples
# Collect images with their original indices
for image_index in range(len(images_mfd_res)): for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index] mfd_res = images_mfd_res[image_index]
pil_img = Image.fromarray(images[image_index]) pil_img = Image.fromarray(images[image_index])
formula_list = [] formula_list = []
for xyxy, conf, cla in zip( for idx, (xyxy, conf, cla) in enumerate(zip(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
): )):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = { new_item = {
"category_id": 13 + int(cla.item()), "category_id": 13 + int(cla.item()),
...@@ -123,19 +164,43 @@ class UnimernetModel(object): ...@@ -123,19 +164,43 @@ class UnimernetModel(object):
} }
formula_list.append(new_item) formula_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax)) bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list)
image_info.append((area, curr_idx, bbox_img))
mf_image_list.append(bbox_img) mf_image_list.append(bbox_img)
images_formula_list.append(formula_list) images_formula_list.append(formula_list)
backfill_list += formula_list backfill_list += formula_list
dataset = MathDataset(mf_image_list, transform=self.mfr_transform) # Stable sort by area
image_info.sort(key=lambda x: x[0]) # sort by area
sorted_indices = [x[1] for x in image_info]
sorted_images = [x[2] for x in image_info]
# Create mapping for results
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
# Create dataset with sorted images
dataset = MathDataset(sorted_images, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# Process batches and store results
mfr_res = [] mfr_res = []
for mf_img in dataloader: for mf_img in dataloader:
mf_img = mf_img.to(self.device) mf_img = mf_img.to(self.device)
with torch.no_grad(): with torch.no_grad():
output = self.model.generate({"image": mf_img}) output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"]) mfr_res.extend(output["pred_str"])
for res, latex in zip(backfill_list, mfr_res):
res["latex"] = latex_rm_whitespace(latex) # Restore original order
unsorted_results = [""] * len(mfr_res)
for new_idx, latex in enumerate(mfr_res):
original_idx = index_mapping[new_idx]
unsorted_results[original_idx] = latex_rm_whitespace(latex)
# Fill results back
for res, latex in zip(backfill_list, unsorted_results):
res["latex"] = latex
return images_formula_list return images_formula_list
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment