"src/vscode:/vscode.git/clone" did not exist on "45a09bebf38a201e92002739a63eaa4b1f608920"
Commit 1fd72f5f authored by myhloli

refactor(magic_pdf): optimize table recognition and layout detection

- Update table recognition logic to process each table individually
- Refactor layout detection to use tqdm for progress tracking
- Optimize OCR recognition by using a single tqdm wrapper
- Improve MFR prediction with a more accurate progress bar
- Simplify MFD prediction by removing unnecessary total calculation
parent 795233d1
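Before the diffs: three of the bullets above (the layout, MFD, and progress-bar items) hinge on one tqdm detail. A `tqdm(...)` wrapped around a `range(...)` already derives its total from `len(range(...))`, so the hand-rolled ceiling division the layout and MFD loops previously passed as `total=` was redundant. A minimal sketch of that equivalence, with toy values (`images` and `batch_size` here are stand-ins, not the repo's data):

```python
# Minimal sketch (toy values; `images` is a stand-in, not the repo's data).
import math
from tqdm import tqdm

images = list(range(10))
batch_size = 3

# The removed pattern computed the batch count by hand:
manual_total = len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0)
assert manual_total == math.ceil(len(images) / batch_size) == 4

# The retained pattern lets tqdm infer the same total from len(range(...)):
for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
    batch = images[index:index + batch_size]  # the last batch holds only 1 item here
```

Where the loop body consumes more than one item per step, as in the MFR and OCR-rec hunks below, the commit instead opens the bar with an explicit per-image `total` and updates it manually.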
@@ -102,10 +102,13 @@ class BatchAnalyze:
                     'single_page_mfdetrec_res':single_page_mfdetrec_res,
                     'layout_res':layout_res,
                 })
-                table_res_list_all_page.append({'table_res_list':table_res_list,
-                                                'lang':_lang,
-                                                'np_array_img':np_array_img,
-                                                })
+                for table_res in table_res_list:
+                    table_img, _ = crop_img(table_res, np_array_img)
+                    table_res_list_all_page.append({'table_res':table_res,
+                                                    'lang':_lang,
+                                                    'table_img':table_img,
+                                                    })
 
         # Text box detection
         det_start = time.time()
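The hunk above flattens the table work queue: instead of queueing one dict per page that carries a whole `table_res_list`, each detected table is cropped up front and queued individually, which is what lets the later "Table Predict" bar advance per table. A toy sketch of the reshaping (the `crop_img` here is a stand-in with the same call shape as magic_pdf's helper):

```python
# Toy sketch of the new queue shape; crop_img is a stand-in for the
# magic_pdf helper of the same name (returns cropped image plus extra info).
def crop_img(table_res, page_img):
    return f"crop({table_res} from {page_img})", None

pages = [
    {'table_res_list': ['t1', 't2'], 'lang': 'en', 'np_array_img': 'page0'},
    {'table_res_list': ['t3'], 'lang': 'en', 'np_array_img': 'page1'},
]

table_res_list_all_page = []
for page in pages:
    for table_res in page['table_res_list']:
        table_img, _ = crop_img(table_res, page['np_array_img'])
        table_res_list_all_page.append({'table_res': table_res,
                                        'lang': page['lang'],
                                        'table_img': table_img})

assert len(table_res_list_all_page) == 3  # one entry per table, not per page
```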
@@ -149,8 +152,8 @@
         table_start = time.time()
         table_count = 0
         # for table_res_list_dict in table_res_list_all_page:
-        for table_res_list_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
-            _lang = table_res_list_dict['lang']
+        for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+            _lang = table_res_dict['lang']
             atom_model_manager = AtomModelSingleton()
             ocr_engine = atom_model_manager.get_atom_model(
                 atom_model_name='ocr',
@@ -168,26 +171,23 @@
                 ocr_engine=ocr_engine,
                 table_sub_model_name='slanet_plus'
             )
-            for res in table_res_list_dict['table_res_list']:
-                new_image, _ = crop_img(res, table_res_list_dict['np_array_img'])
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(new_image)
-                # Check whether a valid result was returned
-                if html_code:
-                    expected_ending = html_code.strip().endswith(
-                        '</html>'
-                    ) or html_code.strip().endswith('</table>')
-                    if expected_ending:
-                        res['html'] = html_code
-                    else:
-                        logger.warning(
-                            'table recognition processing fails, not found expected HTML table end'
-                        )
-                else:
-                    logger.warning(
-                        'table recognition processing fails, not get html return'
-                    )
-            table_count += len(table_res_list_dict['table_res_list'])
-        # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {table_count}')
+            html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
+            # Check whether a valid result was returned
+            if html_code:
+                expected_ending = html_code.strip().endswith(
+                    '</html>'
+                ) or html_code.strip().endswith('</table>')
+                if expected_ending:
+                    table_res_dict['table_res']['html'] = html_code
+                else:
+                    logger.warning(
+                        'table recognition processing fails, not found expected HTML table end'
+                    )
+            else:
+                logger.warning(
+                    'table recognition processing fails, not get html return'
+                )
+        # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')
 
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
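The validity check kept by this hunk accepts recognizer output only when it ends like finished table HTML. The same rule as a standalone helper (this function is illustrative and does not exist in the repo; the diff performs the check inline):

```python
# Standalone restatement of the check; an assumption for illustration only.
def looks_like_complete_table(html_code):
    if not html_code:
        return False  # the "not get html return" warning path
    stripped = html_code.strip()
    # the "expected_ending" condition from the diff
    return stripped.endswith('</html>') or stripped.endswith('</table>')

assert looks_like_complete_table('<table><tr><td>1</td></tr></table>')
assert not looks_like_complete_table('<table><tr><td>1</td>')  # truncated output
assert not looks_like_complete_table('')  # empty result
```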
@@ -33,7 +33,7 @@ class DocLayoutYOLOModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_layout_res = []
         # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size), total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0), desc="Layout Predict"):
+        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
             doclayout_yolo_res = [
                 image_res.cpu()
                 for image_res in self.model.predict(
@@ -16,9 +16,7 @@ class YOLOv8MFDModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_mfd_res = []
         # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size),
-                          total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0),
-                          desc="MFD Predict"):
+        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
             mfd_res = [
                 image_res.cpu()
                 for image_res in self.mfd_model.predict(
@@ -109,12 +109,18 @@ class UnimernetModel(object):
         # Process batches and store results
         mfr_res = []
         # for mf_img in dataloader:
-        for mf_img in tqdm(dataloader, desc="MFR Predict"):
-            mf_img = mf_img.to(dtype=self.model.dtype)
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["fixed_str"])
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+                # Update the progress bar by batch_size at a time; the last batch may be smaller than batch_size
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)
 
         # Restore original order
         unsorted_results = [""] * len(mfr_res)
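This is the "more accurate progress bar" bullet: the old bar ticked once per dataloader batch, so its total was the batch count, while the new bar totals images and advances by each batch's true size, keeping a short final batch from skewing the display. A toy rerun of the arithmetic (stand-in data, no model call):

```python
# Toy rerun of the progress accounting; sorted_images and dataloader are
# stand-ins shaped like the real ones.
from tqdm import tqdm

sorted_images = list(range(10))
batch_size = 4
dataloader = [sorted_images[i:i + batch_size]
              for i in range(0, len(sorted_images), batch_size)]  # sizes 4, 4, 2

with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
    for index, mf_img_batch in enumerate(dataloader):
        # ... self.model.generate(...) would run here ...
        current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
        pbar.update(current_batch_size)  # advances 4, 4, then 2 -> exactly 10
```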
@@ -72,6 +72,7 @@ class PytorchPaddleOCR(TextSystem):
             kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
             kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
             kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
+            # kwargs['rec_batch_num'] = 8
             kwargs['device'] = get_device()
@@ -302,131 +302,139 @@ class TextRecognizer(BaseOCRV20):
         batch_num = self.rec_batch_num
         elapse = 0
         # for beg_img_no in range(0, img_num, batch_num):
-        for beg_img_no in tqdm(range(0, img_num, batch_num), desc='OCR-rec Predict', disable=not tqdm_enable):
-            end_img_no = min(img_num, beg_img_no + batch_num)
-            norm_img_batch = []
-            max_wh_ratio = 0
-            for ino in range(beg_img_no, end_img_no):
-                # h, w = img_list[ino].shape[0:2]
-                h, w = img_list[indices[ino]].shape[0:2]
-                wh_ratio = w * 1.0 / h
-                max_wh_ratio = max(max_wh_ratio, wh_ratio)
-            for ino in range(beg_img_no, end_img_no):
-                if self.rec_algorithm == "SAR":
-                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
-                        img_list[indices[ino]], self.rec_image_shape)
-                    norm_img = norm_img[np.newaxis, :]
-                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
-                    valid_ratios = []
-                    valid_ratios.append(valid_ratio)
-                    norm_img_batch.append(norm_img)
-                elif self.rec_algorithm == "SVTR":
-                    norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
-                                                         self.rec_image_shape)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-                elif self.rec_algorithm == "SRN":
-                    norm_img = self.process_image_srn(img_list[indices[ino]],
-                                                      self.rec_image_shape, 8,
-                                                      self.max_text_length)
-                    encoder_word_pos_list = []
-                    gsrm_word_pos_list = []
-                    gsrm_slf_attn_bias1_list = []
-                    gsrm_slf_attn_bias2_list = []
-                    encoder_word_pos_list.append(norm_img[1])
-                    gsrm_word_pos_list.append(norm_img[2])
-                    gsrm_slf_attn_bias1_list.append(norm_img[3])
-                    gsrm_slf_attn_bias2_list.append(norm_img[4])
-                    norm_img_batch.append(norm_img[0])
-                elif self.rec_algorithm == "CAN":
-                    norm_img = self.norm_img_can(img_list[indices[ino]],
-                                                 max_wh_ratio)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-                    norm_image_mask = np.ones(norm_img.shape, dtype='float32')
-                    word_label = np.ones([1, 36], dtype='int64')
-                    norm_img_mask_batch = []
-                    word_label_list = []
-                    norm_img_mask_batch.append(norm_image_mask)
-                    word_label_list.append(word_label)
-                else:
-                    norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                    max_wh_ratio)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-            norm_img_batch = np.concatenate(norm_img_batch)
-            norm_img_batch = norm_img_batch.copy()
-            if self.rec_algorithm == "SRN":
-                starttime = time.time()
-                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
-                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
-                gsrm_slf_attn_bias1_list = np.concatenate(
-                    gsrm_slf_attn_bias1_list)
-                gsrm_slf_attn_bias2_list = np.concatenate(
-                    gsrm_slf_attn_bias2_list)
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
-                    gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
-                    gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
-                    gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
-                    inp = inp.to(self.device)
-                    encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
-                    gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
-                    gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
-                    gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
-                    backbone_out = self.net.backbone(inp)  # backbone_feat
-                    prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
-                # preds = {"predict": prob_out[2]}
-                preds = {"predict": prob_out["predict"]}
-            elif self.rec_algorithm == "SAR":
-                starttime = time.time()
-                # valid_ratios = np.concatenate(valid_ratios)
-                # inputs = [
-                #     norm_img_batch,
-                #     valid_ratios,
-                # ]
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    inp = inp.to(self.device)
-                    preds = self.net(inp)
-            elif self.rec_algorithm == "CAN":
-                starttime = time.time()
-                norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
-                word_label_list = np.concatenate(word_label_list)
-                inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
-                inp = [torch.from_numpy(e_i) for e_i in inputs]
-                inp = [e_i.to(self.device) for e_i in inp]
-                with torch.no_grad():
-                    outputs = self.net(inp)
-                outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
-                preds = outputs
-            else:
-                starttime = time.time()
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    inp = inp.to(self.device)
-                    prob_out = self.net(inp)
-                if isinstance(prob_out, list):
-                    preds = [v.cpu().numpy() for v in prob_out]
-                else:
-                    preds = prob_out.cpu().numpy()
-            rec_result = self.postprocess_op(preds)
-            for rno in range(len(rec_result)):
-                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
-            elapse += time.time() - starttime
+        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+            index = 0
+            for beg_img_no in range(0, img_num, batch_num):
+                end_img_no = min(img_num, beg_img_no + batch_num)
+                norm_img_batch = []
+                max_wh_ratio = 0
+                for ino in range(beg_img_no, end_img_no):
+                    # h, w = img_list[ino].shape[0:2]
+                    h, w = img_list[indices[ino]].shape[0:2]
+                    wh_ratio = w * 1.0 / h
+                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
+                for ino in range(beg_img_no, end_img_no):
+                    if self.rec_algorithm == "SAR":
+                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                            img_list[indices[ino]], self.rec_image_shape)
+                        norm_img = norm_img[np.newaxis, :]
+                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                        valid_ratios = []
+                        valid_ratios.append(valid_ratio)
+                        norm_img_batch.append(norm_img)
+                    elif self.rec_algorithm == "SVTR":
+                        norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+                                                             self.rec_image_shape)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                    elif self.rec_algorithm == "SRN":
+                        norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                          self.rec_image_shape, 8,
+                                                          self.max_text_length)
+                        encoder_word_pos_list = []
+                        gsrm_word_pos_list = []
+                        gsrm_slf_attn_bias1_list = []
+                        gsrm_slf_attn_bias2_list = []
+                        encoder_word_pos_list.append(norm_img[1])
+                        gsrm_word_pos_list.append(norm_img[2])
+                        gsrm_slf_attn_bias1_list.append(norm_img[3])
+                        gsrm_slf_attn_bias2_list.append(norm_img[4])
+                        norm_img_batch.append(norm_img[0])
+                    elif self.rec_algorithm == "CAN":
+                        norm_img = self.norm_img_can(img_list[indices[ino]],
+                                                     max_wh_ratio)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+                        word_label = np.ones([1, 36], dtype='int64')
+                        norm_img_mask_batch = []
+                        word_label_list = []
+                        norm_img_mask_batch.append(norm_image_mask)
+                        word_label_list.append(word_label)
+                    else:
+                        norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                        max_wh_ratio)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                norm_img_batch = np.concatenate(norm_img_batch)
+                norm_img_batch = norm_img_batch.copy()
+                if self.rec_algorithm == "SRN":
+                    starttime = time.time()
+                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                    gsrm_slf_attn_bias1_list = np.concatenate(
+                        gsrm_slf_attn_bias1_list)
+                    gsrm_slf_attn_bias2_list = np.concatenate(
+                        gsrm_slf_attn_bias2_list)
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
+                        gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
+                        gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
+                        gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
+                        inp = inp.to(self.device)
+                        encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
+                        gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
+                        gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
+                        gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
+                        backbone_out = self.net.backbone(inp)  # backbone_feat
+                        prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
+                    # preds = {"predict": prob_out[2]}
+                    preds = {"predict": prob_out["predict"]}
+                elif self.rec_algorithm == "SAR":
+                    starttime = time.time()
+                    # valid_ratios = np.concatenate(valid_ratios)
+                    # inputs = [
+                    #     norm_img_batch,
+                    #     valid_ratios,
+                    # ]
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        inp = inp.to(self.device)
+                        preds = self.net(inp)
+                elif self.rec_algorithm == "CAN":
+                    starttime = time.time()
+                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+                    word_label_list = np.concatenate(word_label_list)
+                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
+                    inp = [torch.from_numpy(e_i) for e_i in inputs]
+                    inp = [e_i.to(self.device) for e_i in inp]
+                    with torch.no_grad():
+                        outputs = self.net(inp)
+                    outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
+                    preds = outputs
+                else:
+                    starttime = time.time()
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        inp = inp.to(self.device)
+                        prob_out = self.net(inp)
+                    if isinstance(prob_out, list):
+                        preds = [v.cpu().numpy() for v in prob_out]
+                    else:
+                        preds = prob_out.cpu().numpy()
+                rec_result = self.postprocess_op(preds)
+                for rno in range(len(rec_result)):
+                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+                elapse += time.time() - starttime
+                # Update the progress bar one batch at a time; the last batch may be smaller than batch_num
+                current_batch_size = min(batch_num, img_num - index * batch_num)
+                index += 1
+                pbar.update(current_batch_size)
         return rec_res, elapse
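One hedged observation on the bookkeeping in this last hunk (illustrative only, not something the commit does): since `beg_img_no` always equals `index * batch_num`, the increment `min(batch_num, img_num - index * batch_num)` is just the width of the current slice, so `end_img_no - beg_img_no` would yield the same number without tracking `index` by hand:

```python
# Illustrative equivalence only (not part of the commit): the per-batch width
# can be read off the loop bounds, so `index` need not be tracked separately.
img_num, batch_num = 10, 4
index = 0
for beg_img_no in range(0, img_num, batch_num):
    end_img_no = min(img_num, beg_img_no + batch_num)
    assert beg_img_no == index * batch_num
    assert end_img_no - beg_img_no == min(batch_num, img_num - index * batch_num)
    index += 1
```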