Commit 553f250f authored by myhloli's avatar myhloli
Browse files

refactor(magic_pdf): optimize code and improve logging

- Remove unused imports and comments
- Increase MIN_BATCH_INFERENCE_SIZE from 100 to 200
- Comment out VRAM cleaning and logging in batch_analyze.py
- Simplify code in doc_analyze_by_custom_model.py- Add tqdm progress bar in pdf_parse_union_core_v2.py
- Enable tqdm in OCR processing
parent 86058278
import time import time
import cv2 import cv2
import torch
from loguru import logger from loguru import logger
from tqdm import tqdm from tqdm import tqdm
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.libs.config_reader import get_table_recog_config
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list) get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
YOLO_LAYOUT_BASE_BATCH_SIZE = 1 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1 MFD_BASE_BATCH_SIZE = 1
...@@ -86,7 +82,7 @@ class BatchAnalyze: ...@@ -86,7 +82,7 @@ class BatchAnalyze:
# ) # )
# 清理显存 # 清理显存
clean_vram(self.model.device, vram_threshold=8) # clean_vram(self.model.device, vram_threshold=8)
ocr_res_list_all_page = [] ocr_res_list_all_page = []
table_res_list_all_page = [] table_res_list_all_page = []
......
...@@ -188,7 +188,7 @@ def batch_doc_analyze( ...@@ -188,7 +188,7 @@ def batch_doc_analyze(
formula_enable=None, formula_enable=None,
table_enable=None, table_enable=None,
): ):
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100)) MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
batch_size = MIN_BATCH_INFERENCE_SIZE batch_size = MIN_BATCH_INFERENCE_SIZE
images = [] images = []
page_wh_list = [] page_wh_list = []
...@@ -245,8 +245,7 @@ def may_batch_image_analyze( ...@@ -245,8 +245,7 @@ def may_batch_image_analyze(
model_manager = ModelSingleton() model_manager = ModelSingleton()
images = [image for image, _, _ in images_with_extra_info] # images = [image for image, _, _ in images_with_extra_info]
batch_analyze = False
batch_ratio = 1 batch_ratio = 1
device = get_device() device = get_device()
...@@ -269,25 +268,22 @@ def may_batch_image_analyze( ...@@ -269,25 +268,22 @@ def may_batch_image_analyze(
else: else:
batch_ratio = 1 batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}') logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
# batch_analyze = True
elif str(device).startswith('mps'):
# batch_analyze = True
pass
doc_analyze_start = time.time()
# doc_analyze_start = time.time()
batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable) batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
results = batch_model(images_with_extra_info) results = batch_model(images_with_extra_info)
gc_start = time.time() # gc_start = time.time()
clean_memory(get_device()) clean_memory(get_device())
gc_time = round(time.time() - gc_start, 2) # gc_time = round(time.time() - gc_start, 2)
logger.info(f'gc time: {gc_time}') # logger.debug(f'gc time: {gc_time}')
doc_analyze_time = round(time.time() - doc_analyze_start, 2) # doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round(len(images) / doc_analyze_time, 2) # doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
logger.info( # logger.debug(
f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},' # f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
f' speed: {doc_analyze_speed} pages/second' # f' speed: {doc_analyze_speed} pages/second'
) # )
return (idx, results) return idx, results
...@@ -12,6 +12,7 @@ import fitz ...@@ -12,6 +12,7 @@ import fitz
import torch import torch
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from tqdm import tqdm
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType from magic_pdf.config.ocr_content_type import BlockType, ContentType
...@@ -932,17 +933,18 @@ def pdf_parse_union( ...@@ -932,17 +933,18 @@ def pdf_parse_union(
logger.warning('end_page_id is out of range, use pdf_docs length') logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(dataset) - 1 end_page_id = len(dataset) - 1
"""初始化启动时间""" # """初始化启动时间"""
start_time = time.time() # start_time = time.time()
for page_id, page in enumerate(dataset): # for page_id, page in enumerate(dataset):
"""debug时输出每页解析的耗时.""" for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
if debug_mode: # """debug时输出每页解析的耗时."""
time_now = time.time() # if debug_mode:
logger.info( # time_now = time.time()
f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}' # logger.info(
) # f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
start_time = time_now # )
# start_time = time_now
"""解析pdf中的每一页""" """解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id: if start_page_id <= page_id <= end_page_id:
...@@ -988,7 +990,7 @@ def pdf_parse_union( ...@@ -988,7 +990,7 @@ def pdf_parse_union(
lang=lang lang=lang
) )
rec_start = time.time() rec_start = time.time()
ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0] ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
# Verify we have matching counts # Verify we have matching counts
assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}' assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
# Process OCR results for this language # Process OCR results for this language
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment