"vscode:/vscode.git/clone" did not exist on "74fc41c7b4a17cff3b75729a178ed35270895acf"
Unverified Commit a01bd7ed authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1821 from myhloli/dev

perf(mfr): improve Math Formula Recognition by sorting images by area
parents 058c349c 58b6ad8c
import time
import functools
from collections import defaultdict
from typing import Dict, List
class PerformanceStats:
"""性能统计类,用于收集和展示方法执行时间"""
_stats: Dict[str, List[float]] = defaultdict(list)
@classmethod
def add_execution_time(cls, func_name: str, execution_time: float):
"""添加执行时间记录"""
cls._stats[func_name].append(execution_time)
@classmethod
def get_stats(cls) -> Dict[str, dict]:
"""获取统计结果"""
results = {}
for func_name, times in cls._stats.items():
results[func_name] = {
'count': len(times),
'total_time': sum(times),
'avg_time': sum(times) / len(times),
'min_time': min(times),
'max_time': max(times)
}
return results
@classmethod
def print_stats(cls):
"""打印统计结果"""
stats = cls.get_stats()
print("\n性能统计结果:")
print("-" * 80)
print(f"{'方法名':<40} {'调用次数':>8} {'总时间(s)':>12} {'平均时间(s)':>12}")
print("-" * 80)
for func_name, data in stats.items():
print(f"{func_name:<40} {data['count']:8d} {data['total_time']:12.6f} {data['avg_time']:12.6f}")
def measure_time(func):
"""测量方法执行时间的装饰器"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
execution_time = time.time() - start_time
PerformanceStats.add_execution_time(func.__name__, execution_time)
return result
return wrapper
\ No newline at end of file
......@@ -170,13 +170,7 @@ def doc_analyze(
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8:
if gpu_memory >= 40:
batch_ratio = 32
elif gpu_memory >=20:
batch_ratio = 16
elif gpu_memory >= 16:
batch_ratio = 8
elif gpu_memory >= 10:
if gpu_memory >= 10:
batch_ratio = 4
else:
batch_ratio = 2
......
......@@ -100,20 +100,61 @@ class UnimernetModel(object):
res["latex"] = latex_rm_whitespace(latex)
return formula_list
def batch_predict(
self, images_mfd_res: list, images: list, batch_size: int = 64
) -> list:
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
images_formula_list = []
mf_image_list = []
backfill_list = []
image_info = [] # Store (area, original_index, image) tuples
# Collect images with their original indices
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
pil_img = Image.fromarray(images[image_index])
formula_list = []
for xyxy, conf, cla in zip(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
):
for idx, (xyxy, conf, cla) in enumerate(zip(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
)):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": 13 + int(cla.item()),
......@@ -123,19 +164,43 @@ class UnimernetModel(object):
}
formula_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list)
image_info.append((area, curr_idx, bbox_img))
mf_image_list.append(bbox_img)
images_formula_list.append(formula_list)
backfill_list += formula_list
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# Stable sort by area
image_info.sort(key=lambda x: x[0]) # sort by area
sorted_indices = [x[1] for x in image_info]
sorted_images = [x[2] for x in image_info]
# Create mapping for results
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
# Create dataset with sorted images
dataset = MathDataset(sorted_images, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# Process batches and store results
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"])
for res, latex in zip(backfill_list, mfr_res):
res["latex"] = latex_rm_whitespace(latex)
# Restore original order
unsorted_results = [""] * len(mfr_res)
for new_idx, latex in enumerate(mfr_res):
original_idx = index_mapping[new_idx]
unsorted_results[original_idx] = latex_rm_whitespace(latex)
# Fill results back
for res, latex in zip(backfill_list, unsorted_results):
res["latex"] = latex
return images_formula_list
......@@ -21,9 +21,12 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
from concurrent.futures import ThreadPoolExecutor
try:
import torchtext
......@@ -215,7 +218,7 @@ def calculate_contrast(img, img_mode) -> float:
# logger.info(f"contrast: {contrast}")
return round(contrast, 2)
# @measure_time
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
# cid用0xfffd表示,连字符拆开
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
......@@ -489,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else:
return [[x0, y0, x1, y1]]
# @measure_time
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = []
......@@ -923,7 +926,6 @@ def pdf_parse_union(
magic_model = MagicModel(model_list, dataset)
"""根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
......@@ -960,6 +962,8 @@ def pdf_parse_union(
)
pdf_info_dict[f'page_{page_id}'] = page_info
# PerformanceStats.print_stats()
"""分段"""
para_split(pdf_info_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment