Commit 553f250f authored by myhloli
Browse files

refactor(magic_pdf): optimize code and improve logging

- Remove unused imports and comments
- Increase MIN_BATCH_INFERENCE_SIZE from 100 to 200
- Comment out VRAM cleaning and logging in batch_analyze.py
- Simplify code in doc_analyze_by_custom_model.py
- Add tqdm progress bar in pdf_parse_union_core_v2.py
- Enable tqdm in OCR processing
parent 86058278
import time
import cv2
import torch
from loguru import logger
from tqdm import tqdm
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_table_recog_config
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
......@@ -86,7 +82,7 @@ class BatchAnalyze:
# )
# 清理显存
clean_vram(self.model.device, vram_threshold=8)
# clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = []
table_res_list_all_page = []
......
......@@ -188,7 +188,7 @@ def batch_doc_analyze(
formula_enable=None,
table_enable=None,
):
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
batch_size = MIN_BATCH_INFERENCE_SIZE
images = []
page_wh_list = []
......@@ -245,8 +245,7 @@ def may_batch_image_analyze(
model_manager = ModelSingleton()
images = [image for image, _, _ in images_with_extra_info]
batch_analyze = False
# images = [image for image, _, _ in images_with_extra_info]
batch_ratio = 1
device = get_device()
......@@ -269,25 +268,22 @@ def may_batch_image_analyze(
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
# batch_analyze = True
elif str(device).startswith('mps'):
# batch_analyze = True
pass
doc_analyze_start = time.time()
# doc_analyze_start = time.time()
batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
results = batch_model(images_with_extra_info)
gc_start = time.time()
# gc_start = time.time()
clean_memory(get_device())
gc_time = round(time.time() - gc_start, 2)
logger.info(f'gc time: {gc_time}')
doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
logger.info(
f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
f' speed: {doc_analyze_speed} pages/second'
)
return (idx, results)
# gc_time = round(time.time() - gc_start, 2)
# logger.debug(f'gc time: {gc_time}')
# doc_analyze_time = round(time.time() - doc_analyze_start, 2)
# doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
# logger.debug(
# f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
# f' speed: {doc_analyze_speed} pages/second'
# )
return idx, results
......@@ -12,6 +12,7 @@ import fitz
import torch
import numpy as np
from loguru import logger
from tqdm import tqdm
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType
......@@ -932,17 +933,18 @@ def pdf_parse_union(
logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(dataset) - 1
"""初始化启动时间"""
start_time = time.time()
# """初始化启动时间"""
# start_time = time.time()
for page_id, page in enumerate(dataset):
"""debug时输出每页解析的耗时."""
if debug_mode:
time_now = time.time()
logger.info(
f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
)
start_time = time_now
# for page_id, page in enumerate(dataset):
for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
# """debug时输出每页解析的耗时."""
# if debug_mode:
# time_now = time.time()
# logger.info(
# f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
# )
# start_time = time_now
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
......@@ -988,7 +990,7 @@ def pdf_parse_union(
lang=lang
)
rec_start = time.time()
ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts
assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
# Process OCR results for this language
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment