Unverified Commit fe4e62a7 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2077 from myhloli/dev

feat(model): add tqdm progress bar to model prediction loops
parents 09bd890e 1fd72f5f
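
The changes below apply one pattern throughout: each batched prediction loop gains a tqdm progress bar, and the per-stage timing `logger.info` calls are commented out in favor of the bar. A minimal sketch of the pattern, with a hypothetical `model` object (not the project's API):

```python
from tqdm import tqdm

def batch_predict(model, images, batch_size):
    results = []
    # one tick per batch of up to batch_size images
    for start in tqdm(range(0, len(images), batch_size), desc="Predict"):
        results.extend(model.predict(images[start:start + batch_size]))
    return results
```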
@@ -307,7 +307,7 @@ You can modify certain configurations in this file to enable or disable features
     },
     "table-config": {
         "model": "rapid_table",
-        "sub_model": "slanet_plus", // When the model is "rapid_table", you can choose a sub_model. The options are "slanet_plus" and "unitable"
+        "sub_model": "slanet_plus",
         "enable": true, // The table recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
         "max_time": 400
     }
...
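
The inline note removed above still applies: with "model" set to "rapid_table", "sub_model" accepts either "slanet_plus" or "unitable". A sketch of the alternative choice in magic-pdf.json (other keys omitted):

```json
"table-config": {
    "model": "rapid_table",
    "sub_model": "unitable",
    "enable": true,
    "max_time": 400
}
```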
@@ -310,8 +310,8 @@ pip install -U "magic-pdf[full]" -i https://mirrors.aliyun.com/pypi/simple
         "enable": true // Formula recognition is enabled by default; to disable it, change this value to "false"
     },
     "table-config": {
         "model": "rapid_table",
-        "sub_model": "slanet_plus", // When model is "rapid_table", you can choose a sub_model; the options are "slanet_plus" and "unitable"
+        "sub_model": "slanet_plus",
         "enable": true, // Table recognition is enabled by default; to disable it, change this value to "false"
         "max_time": 400
     }
...
@@ -16,4 +16,5 @@ doclayout-yolo==0.0.2b1
 ftfy
 openai
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
\ No newline at end of file
+transformers>=4.49.0,<5.0.0
+tqdm>=4.67.1
\ No newline at end of file
@@ -16,4 +16,5 @@ doclayout-yolo==0.0.2b1
 ftfy
 openai
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
\ No newline at end of file
+transformers>=4.49.0,<5.0.0
+tqdm>=4.67.1
\ No newline at end of file
@@ -16,4 +16,5 @@ doclayout-yolo==0.0.2b1
 ftfy
 openai
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
\ No newline at end of file
+transformers>=4.49.0,<5.0.0
+tqdm>=4.67.1
\ No newline at end of file
 import time
 import cv2
-import torch
 from loguru import logger
+from tqdm import tqdm

 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
@@ -52,9 +51,9 @@ class BatchAnalyze:
                 layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
-        logger.info(
-            f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
-        )
+        # logger.info(
+        #     f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
+        # )

         if self.model.apply_formula:
             # formula detection (MFD)
@@ -63,9 +62,9 @@ class BatchAnalyze:
                 # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
                 images, MFD_BASE_BATCH_SIZE
             )
-            logger.info(
-                f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
-            )
+            # logger.info(
+            #     f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
+            # )

            # formula recognition (MFR)
            mfr_start_time = time.time()
@@ -78,104 +77,117 @@ class BatchAnalyze:
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
                 mfr_count += len(images_formula_list[image_index])
-            logger.info(
-                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
-            )
+            # logger.info(
+            #     f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
+            # )

         # free VRAM
-        clean_vram(self.model.device, vram_threshold=8)
+        # clean_vram(self.model.device, vram_threshold=8)

-        det_time = 0
-        det_count = 0
-        table_time = 0
-        table_count = 0
+        ocr_res_list_all_page = []
+        table_res_list_all_page = []
+        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
         for index in range(len(images)):
             _, ocr_enable, _lang = images_with_extra_info[index]
+            self.model = self.model_manager.get_model(ocr_enable, self.show_log, _lang, self.layout_model, self.formula_enable, self.table_enable)
             layout_res = images_layout_res[index]
             np_array_img = images[index]

             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
             )

-            # OCR recognition
-            det_start = time.time()
+            ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
+                                          'lang':_lang,
+                                          'ocr_enable':ocr_enable,
+                                          'np_array_img':np_array_img,
+                                          'single_page_mfdetrec_res':single_page_mfdetrec_res,
+                                          'layout_res':layout_res,
+                                          })

+            for table_res in table_res_list:
+                table_img, _ = crop_img(table_res, np_array_img)
+                table_res_list_all_page.append({'table_res':table_res,
+                                                'lang':_lang,
+                                                'table_img':table_img,
+                                               })

+        # text-box detection (OCR-det)
+        det_start = time.time()
+        det_count = 0
+        # for ocr_res_list_dict in ocr_res_list_all_page:
+        for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
             # Process each area that requires OCR processing
-            for res in ocr_res_list:
+            _lang = ocr_res_list_dict['lang']
+            # Get OCR results for this language's images
+            atom_model_manager = AtomModelSingleton()
+            ocr_model = atom_model_manager.get_atom_model(
+                atom_model_name='ocr',
+                ocr_show_log=False,
+                det_db_box_thresh=0.3,
+                lang=_lang
+            )
+            for res in ocr_res_list_dict['ocr_res_list']:
                 new_image, useful_list = crop_img(
-                    res, np_array_img, crop_paste_x=50, crop_paste_y=50
+                    res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                 )
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                    single_page_mfdetrec_res, useful_list
+                    ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                 )

-                # OCR recognition
+                # OCR-det
                 new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
-                # if ocr_enable:
-                #     ocr_res = self.model.ocr_model.ocr(
-                #         new_image, mfd_res=adjusted_mfdetrec_res
-                #     )[0]
-                # else:
-                ocr_res = self.model.ocr_model.ocr(
+                ocr_res = ocr_model.ocr(
                     new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                 )[0]

                 # Integration results
                 if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
-                    layout_res.extend(ocr_result_list)
-            det_time += time.time() - det_start
-            det_count += len(ocr_res_list)
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
+                    ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+            det_count += len(ocr_res_list_dict['ocr_res_list'])
+        # logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')

-            # table recognition
-            if self.model.apply_table:
-                table_start = time.time()
-                for res in table_res_list:
-                    new_image, _ = crop_img(res, np_array_img)
-                    single_table_start_time = time.time()
-                    html_code = None
-                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
-                        with torch.no_grad():
-                            table_result = self.model.table_model.predict(
-                                new_image, 'html'
-                            )
-                        if len(table_result) > 0:
-                            html_code = table_result[0]
-                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
-                        html_code = self.model.table_model.img2html(new_image)
-                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                        html_code, table_cell_bboxes, logic_points, elapse = (
-                            self.model.table_model.predict(new_image)
-                        )
-                    run_time = time.time() - single_table_start_time
-                    if run_time > self.model.table_max_time:
-                        logger.warning(
-                            f'table recognition processing exceeds max time {self.model.table_max_time}s'
-                        )
-                    # check whether a valid result was returned
-                    if html_code:
-                        expected_ending = html_code.strip().endswith(
-                            '</html>'
-                        ) or html_code.strip().endswith('</table>')
-                        if expected_ending:
-                            res['html'] = html_code
-                        else:
-                            logger.warning(
-                                'table recognition processing fails, not found expected HTML table end'
-                            )
-                    else:
-                        logger.warning(
-                            'table recognition processing fails, not get html return'
-                        )
-                table_time += time.time() - table_start
-                table_count += len(table_res_list)
-        logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
-        if self.model.apply_table:
-            logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
+        # table recognition
+        if self.model.apply_table:
+            table_start = time.time()
+            table_count = 0
+            # for table_res_list_dict in table_res_list_all_page:
+            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+                _lang = table_res_dict['lang']
+                atom_model_manager = AtomModelSingleton()
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
+                    lang=_lang
+                )
+                table_model = atom_model_manager.get_atom_model(
+                    atom_model_name='table',
+                    table_model_name='rapid_table',
+                    table_model_path='',
+                    table_max_time=400,
+                    device='cpu',
+                    ocr_engine=ocr_engine,
+                    table_sub_model_name='slanet_plus'
+                )
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
+                # check whether a valid result was returned
+                if html_code:
+                    expected_ending = html_code.strip().endswith(
+                        '</html>'
+                    ) or html_code.strip().endswith('</table>')
+                    if expected_ending:
+                        table_res_dict['table_res']['html'] = html_code
+                    else:
+                        logger.warning(
+                            'table recognition processing fails, not found expected HTML table end'
+                        )
+                else:
+                    logger.warning(
+                        'table recognition processing fails, not get html return'
+                    )
+            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')

         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
@@ -219,7 +231,7 @@ class BatchAnalyze:
                     det_db_box_thresh=0.3,
                     lang=lang
                 )
-                ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+                ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]

                 # Verify we have matching counts
                 assert len(ocr_res_list) == len(
@@ -234,7 +246,7 @@ class BatchAnalyze:
                 total_processed += len(img_crop_list)
             rec_time += time.time() - rec_start
-        logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
+        # logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
...
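
Net effect of the restructure above: the old single pass interleaved OCR detection and table recognition per page, so no stage-level progress could be shown. The new code first collects every OCR region and table crop across all pages into `ocr_res_list_all_page` and `table_res_list_all_page`, then runs each model stage over the whole collection under its own bar. A schematic sketch with hypothetical helpers:

```python
from tqdm import tqdm

# pass 1: gather work items for every page up front
ocr_jobs, table_jobs = [], []
for page in pages:
    ocr_jobs.extend(collect_ocr_regions(page))    # hypothetical helper
    table_jobs.extend(collect_table_crops(page))  # hypothetical helper

# pass 2: one progress bar per model stage, spanning all pages
for job in tqdm(ocr_jobs, desc="OCR-det Predict"):
    run_text_detection(job)
for job in tqdm(table_jobs, desc="Table Predict"):
    run_table_recognition(job)
```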
@@ -188,7 +188,7 @@ def batch_doc_analyze(
     formula_enable=None,
     table_enable=None,
 ):
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
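
The default minimum batch size doubles from 100 to 200 pages; it can still be overridden through the environment, e.g. (must be set before `batch_doc_analyze` reads it):

```python
import os
os.environ['MINERU_MIN_BATCH_INFERENCE_SIZE'] = '500'
```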
@@ -245,8 +245,7 @@ def may_batch_image_analyze(
     model_manager = ModelSingleton()

-    images = [image for image, _, _ in images_with_extra_info]
-    batch_analyze = False
+    # images = [image for image, _, _ in images_with_extra_info]
     batch_ratio = 1
     device = get_device()
@@ -269,25 +268,22 @@ def may_batch_image_analyze(
         else:
             batch_ratio = 1
         logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-        # batch_analyze = True
-    elif str(device).startswith('mps'):
-        # batch_analyze = True
-        pass

-    doc_analyze_start = time.time()
+    # doc_analyze_start = time.time()

     batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
     results = batch_model(images_with_extra_info)

-    gc_start = time.time()
+    # gc_start = time.time()
     clean_memory(get_device())
-    gc_time = round(time.time() - gc_start, 2)
-    logger.info(f'gc time: {gc_time}')
+    # gc_time = round(time.time() - gc_start, 2)
+    # logger.debug(f'gc time: {gc_time}')

-    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
-    logger.info(
-        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
-        f' speed: {doc_analyze_speed} pages/second'
-    )
-    return (idx, results)
+    # doc_analyze_time = round(time.time() - doc_analyze_start, 2)
+    # doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
+    # logger.debug(
+    #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+    #     f' speed: {doc_analyze_speed} pages/second'
+    # )
+    return idx, results
 from doclayout_yolo import YOLOv10
+from tqdm import tqdm

 class DocLayoutYOLOModel(object):
@@ -31,7 +32,8 @@ class DocLayoutYOLOModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_layout_res = []
-        for index in range(0, len(images), batch_size):
+        # for index in range(0, len(images), batch_size):
+        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
             doclayout_yolo_res = [
                 image_res.cpu()
                 for image_res in self.model.predict(
...
+from tqdm import tqdm
 from ultralytics import YOLO
@@ -14,7 +15,8 @@ class YOLOv8MFDModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_mfd_res = []
-        for index in range(0, len(images), batch_size):
+        # for index in range(0, len(images), batch_size):
+        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
             mfd_res = [
                 image_res.cpu()
                 for image_res in self.mfd_model.predict(
...
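
In both YOLO wrappers above, the bar iterates over batch start indices, so it advances once per batch rather than once per image: 100 images with batch_size 4 show as 25 ticks. A self-contained sketch of the loop shape:

```python
from tqdm import tqdm

images = list(range(100))  # stand-ins for 100 page images
batch_size = 4
for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
    batch = images[index:index + batch_size]  # up to batch_size items
    # the wrapped model's predict(batch, ...) would run here
```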
 import torch
 from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm

 class MathDataset(Dataset):
@@ -107,12 +108,19 @@ class UnimernetModel(object):
         # Process batches and store results
         mfr_res = []
-        for mf_img in dataloader:
-            mf_img = mf_img.to(dtype=self.model.dtype)
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["fixed_str"])
+        # for mf_img in dataloader:
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+                # advance the bar by batch_size; the last batch may be smaller than batch_size
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)

         # Restore original order
         unsorted_results = [""] * len(mfr_res)
...
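
Here, by contrast, the bar's total is the image count, and each iteration advances by the actual batch size, so the final partial batch is not over-counted. With 10 images and batch_size 4 the updates are 4, 4, then min(4, 10 - 2*4) = 2:

```python
total, batch_size = 10, 4
for index in range(3):  # three batches: 4 + 4 + 2
    print(min(batch_size, total - index * batch_size))  # -> 4, 4, 2
```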
@@ -72,6 +72,7 @@ class PytorchPaddleOCR(TextSystem):
             kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
             kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
             kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
+            # kwargs['rec_batch_num'] = 8
             kwargs['device'] = get_device()
@@ -86,6 +87,7 @@ class PytorchPaddleOCR(TextSystem):
             det=True,
             rec=True,
             mfd_res=None,
+            tqdm_enable=False,
     ):
         assert isinstance(img, (np.ndarray, list, str, bytes))
         if isinstance(img, list) and det == True:
@@ -129,7 +131,7 @@ class PytorchPaddleOCR(TextSystem):
             if not isinstance(img, list):
                 img = preprocess_image(img)
                 img = [img]
-            rec_res, elapse = self.text_recognizer(img)
+            rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
             # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
             ocr_res.append(rec_res)
             return ocr_res
...
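
The new `tqdm_enable` flag defaults to False and is simply forwarded to the recognizer, whose bar is built with `disable=not tqdm_enable`, so existing callers see no change. Opting in mirrors the call sites elsewhere in this PR (assuming an initialized PytorchPaddleOCR instance `ocr_model`):

```python
# recognition-only OCR over pre-cropped images, with the progress bar on
ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
```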
@@ -4,6 +4,8 @@ import numpy as np
 import math
 import time
 import torch
+from tqdm import tqdm

 from ...pytorchocr.base_ocr_v20 import BaseOCRV20
 from . import pytorchocr_utility as utility
 from ...pytorchocr.postprocess import build_post_process
@@ -286,7 +288,7 @@ class TextRecognizer(BaseOCRV20):
         return img

-    def __call__(self, img_list):
+    def __call__(self, img_list, tqdm_enable=False):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
         width_list = []
@@ -299,131 +301,140 @@ class TextRecognizer(BaseOCRV20):
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         elapse = 0
-        for beg_img_no in range(0, img_num, batch_num):
+        # for beg_img_no in range(0, img_num, batch_num):
+        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+            index = 0
+            for beg_img_no in range(0, img_num, batch_num):
                end_img_no = min(img_num, beg_img_no + batch_num)
                norm_img_batch = []
                max_wh_ratio = 0
                for ino in range(beg_img_no, end_img_no):
                    # h, w = img_list[ino].shape[0:2]
                    h, w = img_list[indices[ino]].shape[0:2]
                    wh_ratio = w * 1.0 / h
                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
                for ino in range(beg_img_no, end_img_no):
                    if self.rec_algorithm == "SAR":
                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                            img_list[indices[ino]], self.rec_image_shape)
                        norm_img = norm_img[np.newaxis, :]
                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
                        valid_ratios = []
                        valid_ratios.append(valid_ratio)
                        norm_img_batch.append(norm_img)
                    elif self.rec_algorithm == "SVTR":
                        norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
                                                             self.rec_image_shape)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                    elif self.rec_algorithm == "SRN":
                        norm_img = self.process_image_srn(img_list[indices[ino]],
                                                          self.rec_image_shape, 8,
                                                          self.max_text_length)
                        encoder_word_pos_list = []
                        gsrm_word_pos_list = []
                        gsrm_slf_attn_bias1_list = []
                        gsrm_slf_attn_bias2_list = []
                        encoder_word_pos_list.append(norm_img[1])
                        gsrm_word_pos_list.append(norm_img[2])
                        gsrm_slf_attn_bias1_list.append(norm_img[3])
                        gsrm_slf_attn_bias2_list.append(norm_img[4])
                        norm_img_batch.append(norm_img[0])
                    elif self.rec_algorithm == "CAN":
                        norm_img = self.norm_img_can(img_list[indices[ino]],
                                                     max_wh_ratio)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
                        word_label = np.ones([1, 36], dtype='int64')
                        norm_img_mask_batch = []
                        word_label_list = []
                        norm_img_mask_batch.append(norm_image_mask)
                        word_label_list.append(word_label)
                    else:
                        norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                        max_wh_ratio)
                        norm_img = norm_img[np.newaxis, :]
                        norm_img_batch.append(norm_img)
                norm_img_batch = np.concatenate(norm_img_batch)
                norm_img_batch = norm_img_batch.copy()

                if self.rec_algorithm == "SRN":
                    starttime = time.time()
                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                    gsrm_slf_attn_bias1_list = np.concatenate(
                        gsrm_slf_attn_bias1_list)
                    gsrm_slf_attn_bias2_list = np.concatenate(
                        gsrm_slf_attn_bias2_list)
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch)
                        encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
                        gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
                        gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
                        gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
                        inp = inp.to(self.device)
                        encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
                        gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
                        gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
                        gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
                        backbone_out = self.net.backbone(inp)  # backbone_feat
                        prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
                    # preds = {"predict": prob_out[2]}
                    preds = {"predict": prob_out["predict"]}
                elif self.rec_algorithm == "SAR":
                    starttime = time.time()
                    # valid_ratios = np.concatenate(valid_ratios)
                    # inputs = [
                    #     norm_img_batch,
                    #     valid_ratios,
                    # ]
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch)
                        inp = inp.to(self.device)
                        preds = self.net(inp)
                elif self.rec_algorithm == "CAN":
                    starttime = time.time()
                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
                    word_label_list = np.concatenate(word_label_list)
                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
                    inp = [torch.from_numpy(e_i) for e_i in inputs]
                    inp = [e_i.to(self.device) for e_i in inp]
                    with torch.no_grad():
                        outputs = self.net(inp)
                        outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
                    preds = outputs
                else:
                    starttime = time.time()
                    with torch.no_grad():
                        inp = torch.from_numpy(norm_img_batch)
                        inp = inp.to(self.device)
                        prob_out = self.net(inp)
                    if isinstance(prob_out, list):
                        preds = [v.cpu().numpy() for v in prob_out]
                    else:
                        preds = prob_out.cpu().numpy()
                rec_result = self.postprocess_op(preds)
                for rno in range(len(rec_result)):
                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
                elapse += time.time() - starttime
+                # advance the bar by batch_num; the last batch may be smaller than batch_num
+                current_batch_size = min(batch_num, img_num - index * batch_num)
+                index += 1
+                pbar.update(current_batch_size)

         return rec_res, elapse
@@ -9,7 +9,7 @@ from magic_pdf.libs.config_reader import get_device

 class RapidTableModel(object):
-    def __init__(self, ocr_engine, table_sub_model_name):
+    def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
         sub_model_list = [model.value for model in ModelType]
         if table_sub_model_name is None:
             input_args = RapidTableInput()
...
@@ -12,6 +12,7 @@ import fitz
 import torch
 import numpy as np
 from loguru import logger
+from tqdm import tqdm

 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.ocr_content_type import BlockType, ContentType
@@ -932,17 +933,18 @@ def pdf_parse_union(
         logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(dataset) - 1

-    """Record the start time"""
-    start_time = time.time()

-    for page_id, page in enumerate(dataset):
-        """In debug mode, log the time spent parsing each page."""
-        if debug_mode:
-            time_now = time.time()
-            logger.info(
-                f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
-            )
-            start_time = time_now
+    # """Record the start time"""
+    # start_time = time.time()

+    # for page_id, page in enumerate(dataset):
+    for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
+        # """In debug mode, log the time spent parsing each page."""
+        # if debug_mode:
+        #     time_now = time.time()
+        #     logger.info(
+        #         f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
+        #     )
+        #     start_time = time_now

         """Parse each page of the PDF"""
         if start_page_id <= page_id <= end_page_id:
@@ -987,8 +989,8 @@ def pdf_parse_union(
                     det_db_box_thresh=0.3,
                     lang=lang
                 )
-                rec_start = time.time()
-                ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+                # rec_start = time.time()
+                ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]

                 # Verify we have matching counts
                 assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'

                 # Process OCR results for this language
@@ -996,8 +998,8 @@ def pdf_parse_union(
                     ocr_text, ocr_score = ocr_res_list[index]
                     span['content'] = ocr_text
                     span['score'] = float(round(ocr_score, 2))
-                rec_time = time.time() - rec_start
-                logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')
+                # rec_time = time.time() - rec_start
+                # logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')

         """Segmentation"""
...
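
One detail worth noting above: `enumerate()` has no `__len__`, so wrapping it in tqdm alone would show a bare count with no percentage; passing `total=len(dataset)` restores the full bar:

```python
from tqdm import tqdm

dataset = ['page0', 'page1', 'page2']
for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
    pass  # per-page parsing would run here
```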
@@ -11,4 +11,5 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
 torchvision
 transformers>=4.49.0,<5.0.0
 pdfminer.six==20231228
+tqdm>=4.67.1
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.