Commit 29e590a7 authored by myhloli's avatar myhloli
Browse files

refactor(pdf_extract_kit): optimize image processing and table recognition...

refactor(pdf_extract_kit): optimize image processing and table recognition logicRefactor the image processing logic for OCR and table recognition to ensure
consistency and improve performance. Remove redundant initialization of PIL images,
unify image cropping logic, and streamline the handling of formula detection results.
Also, adjust the table recognition process to improve integration with the updated image
processing logic and enhance overall efficiency.
parent ad5596fc
...@@ -27,7 +27,7 @@ except ImportError as e: ...@@ -27,7 +27,7 @@ except ImportError as e:
logger.exception(e) logger.exception(e)
logger.error( logger.error(
'Required dependency not installed, please install by \n' 'Required dependency not installed, please install by \n'
'"pip install magic-pdf[full] detectron2 --extra-index-url https://myhloli.github.io/wheels/"') '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
exit(1) exit(1)
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
...@@ -188,50 +188,56 @@ class CustomPEKModel: ...@@ -188,50 +188,56 @@ class CustomPEKModel:
mfr_cost = round(time.time() - mfr_start, 2) mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}") logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# Select regions for OCR / formula regions / table regions
ocr_res_list = []
table_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
elif int(res['category_id']) in [5]:
table_res_list.append(res)
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background with an additional width and height of 50
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
return return_image, return_list
pil_img = Image.fromarray(image)
# ocr识别 # ocr识别
if self.apply_ocr: if self.apply_ocr:
ocr_start = time.time() ocr_start = time.time()
pil_img = Image.fromarray(image) # Process each area that requires OCR processing
# 筛选出需要OCR的区域和公式区域
ocr_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
# 对每一个需OCR处理的区域进行处理
for res in ocr_res_list: for res in ocr_res_list:
xmin, ymin = int(res['poly'][0]), int(res['poly'][1]) new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
xmax, ymax = int(res['poly'][4]), int(res['poly'][5]) paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
# 调整公式区域坐标
adjusted_mfdetrec_res = [] adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res: for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# 将公式区域坐标调整为相对于裁剪区域的坐标 # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y y1 = mf_ymax - ymin + paste_y
# 过滤在图外的公式块 # Filter formula blocks outside the graph
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue continue
else: else:
...@@ -239,17 +245,17 @@ class CustomPEKModel: ...@@ -239,17 +245,17 @@ class CustomPEKModel:
"bbox": [x0, y0, x1, y1], "bbox": [x0, y0, x1, y1],
}) })
# OCR识别 # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
# 整合结果 # Integration results
if ocr_res: if ocr_res:
for box_ocr_res in ocr_res: for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0] p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1] text, score = box_ocr_res[1]
# 将坐标转换回原图坐标系 # Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
...@@ -267,35 +273,23 @@ class CustomPEKModel: ...@@ -267,35 +273,23 @@ class CustomPEKModel:
# 表格识别 table recognition # 表格识别 table recognition
if self.apply_table: if self.apply_table:
pil_img = Image.fromarray(image) table_start = time.time()
for layout in layout_res: for res in table_res_list:
if layout.get("category_id", -1) == 5: new_image, _ = crop_img(res, pil_img)
poly = layout["poly"] single_table_start_time = time.time()
xmin, ymin = int(poly[0]), int(poly[1]) logger.info("------------------table recognition processing begins-----------------")
xmax, ymax = int(poly[4]), int(poly[5]) with torch.no_grad():
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像 crop image
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
start_time = time.time()
logger.info("------------------table recognition processing begins-----------------")
latex_code = self.table_model.image2latex(new_image)[0] latex_code = self.table_model.image2latex(new_image)[0]
end_time = time.time() run_time = time.time() - single_table_start_time
run_time = end_time - start_time logger.info(f"------------table recognition processing ends within {run_time}s-----")
logger.info(f"------------table recognition processing ends within {run_time}s-----") if run_time > self.table_max_time:
if run_time > self.table_max_time: logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------") # 判断是否返回正常
# 判断是否返回正常 if latex_code and latex_code.strip().endswith('end{tabular}'):
if latex_code and latex_code.strip().endswith('end{tabular}'): res["latex"] = latex_code
layout["latex"] = latex_code else:
else: logger.warning(f"------------table recognition processing fails----------")
logger.warning(f"------------table recognition processing fails----------") table_cost = round(time.time() - table_start, 2)
logger.info(f"table cost: {table_cost}")
return layout_res return layout_res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment