Commit b29badc1 authored by liukaiwen's avatar liukaiwen
Browse files

# add table recognition using struct-eqtable

## Changelog
31/07/20204
- Support table recognition. Table images will be converted into html.

### how to use the new feature:
set the attribute 'table-mode' to 'true' in magic-pdf.json

### caution:
it takes 200s to 500s to convert a single table image using cpu
parent 724db33d
......@@ -5,5 +5,6 @@
},
"temp-output-dir":"/tmp",
"models-dir":"/tmp/models",
"device-mode":"cpu"
"device-mode":"cpu",
"table-mode":"false"
}
\ No newline at end of file
......@@ -128,7 +128,11 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout, mode, img_buket_path=""):
for line in block['lines']:
for span in line['spans']:
if span['type'] == ContentType.Table:
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
# if processed by table model
if span.get('content', ''):
para_text += f"\n {span['content']} \n"
else:
para_text += f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"
for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TableFootnote:
para_text += merge_para_with_text(block)
......@@ -244,6 +248,9 @@ def para_to_standard_format_v2(para_block, img_buket_path):
}
for block in para_block['blocks']:
if block['type'] == BlockType.TableBody:
#TODO
if block["lines"][0]["spans"][0].get('content', ''):
para_content['table_body'] = f"\n {block['lines'][0]['spans'][0]['content']} \n"
para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
if block['type'] == BlockType.TableCaption:
para_content['table_caption'] = merge_para_with_text(block)
......
......@@ -86,6 +86,23 @@ def get_device():
else:
return device
def get_table_mode():
config = read_config()
table_mode = config.get("table-mode")
if table_mode is None:
logger.warning(f"'table-mode' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return False
else:
table_mode = table_mode.lower()
if table_mode == "true":
boolean_value = True
elif table_mode == "False":
boolean_value = False
else:
logger.warning(f"invalid 'table-mode' value in {CONFIG_FILE_NAME}, use 'False' as default")
boolean_value = False
return boolean_value
if __name__ == "__main__":
ak, sk, endpoint = get_s3_config("llm-raw")
......@@ -4,7 +4,7 @@ import fitz
import numpy as np
from loguru import logger
from magic_pdf.libs.config_reader import get_local_models_dir, get_device
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_mode
from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config
......@@ -82,7 +82,13 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
# 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir()
device = get_device()
custom_model = CustomPEKModel(ocr=ocr, show_log=show_log, models_dir=local_models_dir, device=device)
table_mode = get_table_mode()
model_input = {"ocr": ocr,
"show_log": show_log,
"models_dir": local_models_dir,
"device": device,
"table_mode": table_mode}
custom_model = CustomPEKModel(**model_input)
else:
logger.error("Not allow model_name!")
exit(1)
......
......@@ -560,6 +560,14 @@ class MagicModel:
if category_id == 3:
span["type"] = ContentType.Image
elif category_id == 5:
# 获取table模型结果
html = layout_det.get("html", None)
latex = layout_det.get("latex", None)
if html:
span["content"] = html
elif latex:
span["content"] = latex
span["type"] = ContentType.Table
elif category_id == 13:
span["content"] = layout_det["latex"]
......
from loguru import logger
import os
import time
from pypandoc import convert_text
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
try:
......@@ -10,6 +11,7 @@ try:
import numpy as np
import torch
import torchtext
if torchtext.__version__ >= "0.18.0":
torchtext.disable_torchtext_deprecation_warning()
from PIL import Image
......@@ -30,6 +32,12 @@ except ImportError as e:
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
def table_model_init(model_path):
table_model = StructTableModel(model_path)
return table_model
def mfd_model_init(weight):
......@@ -95,6 +103,7 @@ class CustomPEKModel:
# 初始化解析配置
self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
self.apply_table = kwargs.get("table_mode", self.configs["config"]["table"])
self.apply_ocr = ocr
logger.info(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
......@@ -129,6 +138,9 @@ class CustomPEKModel:
if self.apply_ocr:
self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
# init structeqtable
if self.apply_table:
self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])))
logger.info('DocAnalysis init done!')
def __call__(self, image):
......@@ -249,4 +261,39 @@ class CustomPEKModel:
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr cost: {ocr_cost}")
# 表格识别 table recognition
if self.apply_table:
pil_img = Image.fromarray(image)
for layout in layout_res:
if layout.get("category_id", -1) == 5:
poly = layout["poly"]
xmin, ymin = int(poly[0]), int(poly[1])
xmax, ymax = int(poly[4]), int(poly[5])
paste_x = 50
paste_y = 50
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
new_width = xmax - xmin + paste_x * 2
new_height = ymax - ymin + paste_y * 2
new_image = Image.new('RGB', (new_width, new_height), 'white')
# 裁剪图像 crop image
crop_box = (xmin, ymin, xmax, ymax)
cropped_img = pil_img.crop(crop_box)
new_image.paste(cropped_img, (paste_x, paste_y))
start_time = time.time()
print("------------------table recognition processing begins-----------------")
latex_code = self.table_model.image2latex(new_image)[0]
end_time = time.time()
run_time = end_time - start_time
print(f"------------table recognition processing ends within {run_time}s-----")
# try to convert latex to html
try:
html_code = convert_text(latex_code, 'html', format='latex')
layout["html"] = html_code
except Exception as e:
layout["latex"] = latex_code
logger.error(f"[pdf_extract_kit][CustomPEKModel]: converting latex to html failed: {e}")
return layout_res
from struct_eqtable.model import StructTable
from pypandoc import convert_text
class StructTableModel:
def __init__(self, model_path, max_new_tokens=2048, max_time=400):
# init
self.model_path = model_path
self.max_new_tokens = max_new_tokens # maximum output tokens length
self.max_time = max_time # timeout for processing in seconds
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
def image2latex(self, image) -> str:
#
table_latex = self.model.forward(image)
return table_latex
def image2html(self, image) -> str:
table_latex = self.image2latex(image)
table_html = convert_text(table_latex, 'html', format='latex')
return table_html
......@@ -2,8 +2,10 @@ config:
device: cpu
layout: True
formula: True
table: False
weights:
layout: Layout/model_final.pth
mfd: MFD/weights.pt
mfr: MFR/UniMERNet
table: Table/
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment