Commit 8e1c2339 authored by myhloli

feat(model): add tqdm progress bar to model prediction loops

- Add tqdm progress bar to batch prediction loops in multiple model modules
- Improve logging and error handling in batch analysis script
- Update table model initialization to use default sub-model if none specified
- Add tqdm dependency to requirements.txt
parent ddfeea94
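The diff below applies one pattern repeatedly: each batch prediction loop is wrapped in tqdm so long runs report progress, with total set to the number of batches. A minimal, self-contained sketch of that pattern, assuming nothing beyond tqdm itself (predict here is a hypothetical stand-in for the real sub-model call, not the project's API):

from typing import Callable, List

from tqdm import tqdm


def batch_predict(images: List, predict: Callable[[List], List], batch_size: int) -> List:
    """Apply `predict` to `images` in fixed-size chunks, one progress tick per batch."""
    results = []
    # total is the number of batches, i.e. ceil(len(images) / batch_size).
    for start in tqdm(range(0, len(images), batch_size),
                      total=(len(images) + batch_size - 1) // batch_size,
                      desc="Predict"):
        results.extend(predict(images[start:start + batch_size]))
    return results


# Usage with a dummy "model" that doubles each input value.
print(batch_predict(list(range(10)), lambda batch: [x * 2 for x in batch], batch_size=3))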
@@ -3,13 +3,16 @@ import time
import cv2
import torch
from loguru import logger
+from tqdm import tqdm
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_table_recog_config
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
@@ -52,9 +55,9 @@ class BatchAnalyze:
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
-logger.info(
-f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
-)
+# logger.info(
+# f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
+# )
if self.model.apply_formula:
# formula detection
@@ -63,9 +66,9 @@
# images, self.batch_ratio * MFD_BASE_BATCH_SIZE
images, MFD_BASE_BATCH_SIZE
)
-logger.info(
-f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
-)
+# logger.info(
+# f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
+# )
# formula recognition
mfr_start_time = time.time()
@@ -78,82 +81,100 @@ class BatchAnalyze:
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
-logger.info(
-f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
-)
+# logger.info(
+# f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
+# )
# free up GPU memory (VRAM)
clean_vram(self.model.device, vram_threshold=8)
-det_time = 0
-det_count = 0
-table_time = 0
-table_count = 0
+# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
+ocr_res_list_all_page = []
+table_res_list_all_page = []
for index in range(len(images)):
_, ocr_enable, _lang = images_with_extra_info[index]
self.model = self.model_manager.get_model(ocr_enable, self.show_log, _lang, self.layout_model, self.formula_enable, self.table_enable)
layout_res = images_layout_res[index]
np_array_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
# OCR recognition
ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
'lang':_lang,
'ocr_enable':ocr_enable,
'np_array_img':np_array_img,
'single_page_mfdetrec_res':single_page_mfdetrec_res,
'layout_res':layout_res,
})
table_res_list_all_page.append({'table_res_list':table_res_list,
'lang':_lang,
'np_array_img':np_array_img,
})
# text box detection
det_start = time.time()
+det_count = 0
+# for ocr_res_list_dict in ocr_res_list_all_page:
+for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
# Process each area that requires OCR processing
-for res in ocr_res_list:
+_lang = ocr_res_list_dict['lang']
+# Get OCR results for this language's images
+atom_model_manager = AtomModelSingleton()
+ocr_model = atom_model_manager.get_atom_model(
+atom_model_name='ocr',
+ocr_show_log=False,
+det_db_box_thresh=0.3,
+lang=_lang
+)
+for res in ocr_res_list_dict['ocr_res_list']:
new_image, useful_list = crop_img(
-res, np_array_img, crop_paste_x=50, crop_paste_y=50
+res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-single_page_mfdetrec_res, useful_list
+ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
)
-# OCR recognition
+# OCR-det
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
# if ocr_enable:
# ocr_res = self.model.ocr_model.ocr(
# new_image, mfd_res=adjusted_mfdetrec_res
# )[0]
# else:
-ocr_res = self.model.ocr_model.ocr(
+ocr_res = ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res, rec=False
)[0]
# Integration results
if ocr_res:
-ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
-layout_res.extend(ocr_result_list)
-det_time += time.time() - det_start
-det_count += len(ocr_res_list)
+ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
+ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+det_count += len(ocr_res_list_dict['ocr_res_list'])
+# logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')
# table recognition
if self.model.apply_table:
table_start = time.time()
-for res in table_res_list:
-new_image, _ = crop_img(res, np_array_img)
-single_table_start_time = time.time()
-html_code = None
-if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
-with torch.no_grad():
-table_result = self.model.table_model.predict(
-new_image, 'html'
-)
-if len(table_result) > 0:
-html_code = table_result[0]
-elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
-html_code = self.model.table_model.img2html(new_image)
-elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-html_code, table_cell_bboxes, logic_points, elapse = (
-self.model.table_model.predict(new_image)
+table_count = 0
+# for table_res_list_dict in table_res_list_all_page:
+for table_res_list_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+_lang = table_res_list_dict['lang']
+atom_model_manager = AtomModelSingleton()
+ocr_engine = atom_model_manager.get_atom_model(
+atom_model_name='ocr',
+ocr_show_log=False,
+det_db_box_thresh=0.5,
+det_db_unclip_ratio=1.6,
+lang=_lang
)
-run_time = time.time() - single_table_start_time
-if run_time > self.model.table_max_time:
-logger.warning(
-f'table recognition processing exceeds max time {self.model.table_max_time}s'
+table_model = atom_model_manager.get_atom_model(
+atom_model_name='table',
+table_model_name='rapid_table',
+table_model_path='',
+table_max_time=400,
+device='cpu',
+ocr_engine=ocr_engine,
+table_sub_model_name='slanet_plus'
+)
+for res in table_res_list_dict['table_res_list']:
+new_image, _ = crop_img(res, table_res_list_dict['np_array_img'])
+html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(new_image)
# check whether a valid result was returned
if html_code:
expected_ending = html_code.strip().endswith(
@@ -169,13 +190,8 @@
logger.warning(
'table recognition processing fails, not get html return'
)
-table_time += time.time() - table_start
-table_count += len(table_res_list)
-logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
-if self.model.apply_table:
-logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
+table_count += len(table_res_list_dict['table_res_list'])
+# logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {table_count}')
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
@@ -219,7 +235,7 @@ class BatchAnalyze:
det_db_box_thresh=0.3,
lang=lang
)
-ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(
@@ -234,7 +250,7 @@ class BatchAnalyze:
total_processed += len(img_crop_list)
rec_time += time.time() - rec_start
-logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
+# logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
......
from doclayout_yolo import YOLOv10
+from tqdm import tqdm
class DocLayoutYOLOModel(object):
@@ -31,7 +32,8 @@ class DocLayoutYOLOModel(object):
def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = []
-for index in range(0, len(images), batch_size):
+# for index in range(0, len(images), batch_size):
+for index in tqdm(range(0, len(images), batch_size), total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0), desc="Layout Predict"):
doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
......
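In the loop above, total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0) is a hand-rolled ceiling division over the batch count. A short equivalence sketch using only the standard library, offered as a possible simplification rather than what the commit does:

import math


def batch_count(n_images: int, batch_size: int) -> int:
    # Same value as: n_images // batch_size + (1 if n_images % batch_size != 0 else 0)
    return math.ceil(n_images / batch_size)


assert batch_count(10, 3) == 4  # 10 images in batches of 3 -> 4 batches
assert batch_count(9, 3) == 3   # exact multiple -> no extra batch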
+from tqdm import tqdm
from ultralytics import YOLO
@@ -14,7 +15,10 @@ class YOLOv8MFDModel(object):
def batch_predict(self, images: list, batch_size: int) -> list:
images_mfd_res = []
-for index in range(0, len(images), batch_size):
+# for index in range(0, len(images), batch_size):
+for index in tqdm(range(0, len(images), batch_size),
+total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0),
+desc="MFD Predict"):
mfd_res = [
image_res.cpu()
for image_res in self.mfd_model.predict(
......
import torch
from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
class MathDataset(Dataset):
@@ -107,7 +108,8 @@ class UnimernetModel(object):
# Process batches and store results
mfr_res = []
-for mf_img in dataloader:
+# for mf_img in dataloader:
+for mf_img in tqdm(dataloader, desc="MFR Predict"):
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device)
with torch.no_grad():
......
@@ -86,6 +86,7 @@ class PytorchPaddleOCR(TextSystem):
det=True,
rec=True,
mfd_res=None,
+tqdm_enable=False,
):
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
@@ -129,7 +130,7 @@ class PytorchPaddleOCR(TextSystem):
if not isinstance(img, list):
img = preprocess_image(img)
img = [img]
-rec_res, elapse = self.text_recognizer(img)
+rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
ocr_res.append(rec_res)
return ocr_res
......
@@ -4,6 +4,8 @@ import numpy as np
import math
import time
import torch
+from tqdm import tqdm
from ...pytorchocr.base_ocr_v20 import BaseOCRV20
from . import pytorchocr_utility as utility
from ...pytorchocr.postprocess import build_post_process
@@ -286,7 +288,7 @@ class TextRecognizer(BaseOCRV20):
return img
-def __call__(self, img_list):
+def __call__(self, img_list, tqdm_enable=False):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
@@ -299,7 +301,8 @@ class TextRecognizer(BaseOCRV20):
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
-for beg_img_no in range(0, img_num, batch_num):
+# for beg_img_no in range(0, img_num, batch_num):
+for beg_img_no in tqdm(range(0, img_num, batch_num), desc='OCR-rec Predict', disable=not tqdm_enable):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
......
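The two hunks above thread an opt-in flag through the OCR stack: ocr(..., tqdm_enable=True) forwards the flag to TextRecognizer.__call__, which passes disable=not tqdm_enable to tqdm, so the recognition bar appears only for the large batched call in BatchAnalyze and not for per-page inner loops. A small sketch of the same opt-in pattern (the function and the str.upper() body are hypothetical stand-ins, not the project's API):

from tqdm import tqdm


def recognize_all(crops: list, tqdm_enable: bool = False) -> list:
    results = []
    # disable=True turns tqdm into a no-op wrapper, so small per-page calls stay
    # silent while the large batched call can opt in to a progress bar.
    for crop in tqdm(crops, desc='OCR-rec Predict', disable=not tqdm_enable):
        results.append(str(crop).upper())  # stand-in for the real recognizer
    return results


print(recognize_all(['abc', 'def'], tqdm_enable=True))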
@@ -9,7 +9,7 @@ from magic_pdf.libs.config_reader import get_device
class RapidTableModel(object):
-def __init__(self, ocr_engine, table_sub_model_name):
+def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
......
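The rapid_table change gives table_sub_model_name a default of 'slanet_plus', so callers can omit the argument; per the surrounding code, an explicit None still falls back to a bare RapidTableInput(). A hedged sketch of that fallback shape (the class is illustrative only):

class TableModelStub:
    def __init__(self, ocr_engine, table_sub_model_name: str = 'slanet_plus'):
        if table_sub_model_name is None:
            # Mirrors the None branch above: fall back to the library default.
            table_sub_model_name = 'library-default'
        self.sub_model = table_sub_model_name


assert TableModelStub(ocr_engine=None).sub_model == 'slanet_plus'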
@@ -11,4 +11,5 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torchvision
transformers>=4.49.0,<5.0.0
pdfminer.six==20231228
+tqdm>=4.67.1
# The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.