Commit 1fd72f5f authored by myhloli

refactor(magic_pdf): optimize table recognition and layout detection

- Update table recognition logic to process each table individually
- Refactor layout detection to use tqdm for progress tracking
- Optimize OCR recognition by using a single tqdm wrapper
- Improve MFR prediction with a more accurate progress bar
- Simplify MFD prediction by removing unnecessary total calculation
parent 795233d1
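The common thread in the hunks below is swapping tqdm-wrapped iteration over batches (a bar that counts batches) for an item-level bar advanced manually after each batch. A minimal sketch of that pattern, where `items` and `process` are hypothetical stand-ins, not names from this commit:

```python
from tqdm import tqdm

def run_batched(items, batch_size, process):
    # Item-level progress over batched work; `process` stands in
    # for the per-batch model call.
    results = []
    with tqdm(total=len(items), desc="Predict") as pbar:
        for start in range(0, len(items), batch_size):
            batch = items[start:start + batch_size]
            results.extend(process(batch))
            # The final batch may be short; advance by its true size.
            pbar.update(len(batch))
    return results
```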
@@ -102,9 +102,12 @@ class BatchAnalyze:
                 'single_page_mfdetrec_res':single_page_mfdetrec_res,
                 'layout_res':layout_res,
             })
-            table_res_list_all_page.append({'table_res_list':table_res_list,
-                'lang':_lang,
-                'np_array_img':np_array_img,
-            })
+            for table_res in table_res_list:
+                table_img, _ = crop_img(table_res, np_array_img)
+                table_res_list_all_page.append({'table_res':table_res,
+                    'lang':_lang,
+                    'table_img':table_img,
+                })

         # Text box detection
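This hunk changes the shape of table_res_list_all_page from one entry per page to one entry per table, with the crop done up front. A self-contained sketch of the flattening, where bbox slicing stands in for the repo's crop_img helper and the 'bbox' key is an assumption for illustration:

```python
def flatten_tables(pages):
    # `pages`: hypothetical dicts shaped like the old entries:
    # {'table_res_list': [...], 'lang': str, 'np_array_img': ndarray}
    work_items = []
    for page in pages:
        for table_res in page['table_res_list']:
            x0, y0, x1, y1 = table_res['bbox']  # assumed bbox layout
            table_img = page['np_array_img'][y0:y1, x0:x1]
            work_items.append({'table_res': table_res,
                               'lang': page['lang'],
                               'table_img': table_img})
    return work_items
```

One entry per table means the "Table Predict" bar below ticks once per table rather than once per page.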
@@ -149,8 +152,8 @@ class BatchAnalyze:
         table_start = time.time()
         table_count = 0
         # for table_res_list_dict in table_res_list_all_page:
-        for table_res_list_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
-            _lang = table_res_list_dict['lang']
+        for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+            _lang = table_res_dict['lang']
             atom_model_manager = AtomModelSingleton()
             ocr_engine = atom_model_manager.get_atom_model(
                 atom_model_name='ocr',
@@ -168,16 +171,14 @@ class BatchAnalyze:
                 ocr_engine=ocr_engine,
                 table_sub_model_name='slanet_plus'
             )
-            for res in table_res_list_dict['table_res_list']:
-                new_image, _ = crop_img(res, table_res_list_dict['np_array_img'])
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(new_image)
+            html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
             # Check whether a valid result was returned
             if html_code:
                 expected_ending = html_code.strip().endswith(
                     '</html>'
                 ) or html_code.strip().endswith('</table>')
                 if expected_ending:
-                    res['html'] = html_code
+                    table_res_dict['table_res']['html'] = html_code
                 else:
                     logger.warning(
                         'table recognition processing fails, not found expected HTML table end'
@@ -186,8 +187,7 @@ class BatchAnalyze:
                     logger.warning(
                         'table recognition processing fails, not get html return'
                     )
-            table_count += len(table_res_list_dict['table_res_list'])
-        # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {table_count}')
+        # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')

         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
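The HTML-ending check this hunk keeps can be read as a small predicate; restated standalone:

```python
def looks_complete(html_code):
    # Mirrors the check above: output should close with </html> or </table>.
    if not html_code:
        return False
    stripped = html_code.strip()
    return stripped.endswith('</html>') or stripped.endswith('</table>')
```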
...
@@ -33,7 +33,7 @@ class DocLayoutYOLOModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_layout_res = []
         # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size), total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0), desc="Layout Predict"):
+        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
             doclayout_yolo_res = [
                 image_res.cpu()
                 for image_res in self.model.predict(
...
@@ -16,9 +16,7 @@ class YOLOv8MFDModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_mfd_res = []
         # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size),
-                          total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0),
-                          desc="MFD Predict"):
+        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
             mfd_res = [
                 image_res.cpu()
                 for image_res in self.mfd_model.predict(
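Dropping the hand-computed total here and in the layout hunk above is safe because a strided range already knows its own length and tqdm reads it via len(); a quick check of the equivalence:

```python
import math

n, batch = 10, 3
# len(range(0, n, batch)) equals ceil(n / batch), which is what the
# removed expression computed by hand, so tqdm infers the same total.
assert len(range(0, n, batch)) == math.ceil(n / batch) == 4
```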
...
@@ -109,13 +109,19 @@ class UnimernetModel(object):
         # Process batches and store results
         mfr_res = []
         # for mf_img in dataloader:
-        for mf_img in tqdm(dataloader, desc="MFR Predict"):
-            mf_img = mf_img.to(dtype=self.model.dtype)
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["fixed_str"])
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+                # Advance by batch_size, but note the last batch may be smaller
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)

         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
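The DataLoader yields whole batches, so the bar is advanced manually with the same min() arithmetic as above. A self-contained sketch, with plain lists standing in for the DataLoader and model:

```python
from tqdm import tqdm

items = list(range(103))                     # stand-in for sorted_images
batch_size = 16
batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

with tqdm(total=len(items), desc="MFR Predict") as pbar:
    for index, batch in enumerate(batches):
        # ... run the model on `batch` here ...
        # min() keeps the final, short batch from over-counting.
        pbar.update(min(batch_size, len(items) - index * batch_size))
```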
...
@@ -72,6 +72,7 @@ class PytorchPaddleOCR(TextSystem):
             kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
             kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
             kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
+            # kwargs['rec_batch_num'] = 8
             kwargs['device'] = get_device()
...
@@ -302,7 +302,9 @@ class TextRecognizer(BaseOCRV20):
         batch_num = self.rec_batch_num
         elapse = 0
         # for beg_img_no in range(0, img_num, batch_num):
-        for beg_img_no in tqdm(range(0, img_num, batch_num), desc='OCR-rec Predict', disable=not tqdm_enable):
+        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+            index = 0
+            for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
             max_wh_ratio = 0
@@ -429,4 +431,10 @@ class TextRecognizer(BaseOCRV20):
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
             elapse += time.time() - starttime
+            # Advance by batch_num, but note the last batch may be smaller
+            current_batch_size = min(batch_num, img_num - index * batch_num)
+            index += 1
+            pbar.update(current_batch_size)
         return rec_res, elapse
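Since this loop already computes a clamped end_img_no, the separate index counter is avoidable; a hedged alternative sketch with the same behavior (the standalone numbers are for illustration only):

```python
from tqdm import tqdm

img_num, batch_num = 103, 16
with tqdm(total=img_num, desc='OCR-rec Predict') as pbar:
    for beg_img_no in range(0, img_num, batch_num):
        end_img_no = min(img_num, beg_img_no + batch_num)
        # ... recognize images [beg_img_no:end_img_no] here ...
        # end_img_no is already clamped to img_num, so the short
        # final batch advances the bar by its true size for free.
        pbar.update(end_img_no - beg_img_no)
```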