Commit 11f23843 authored by myhloli's avatar myhloli
Browse files

feat(table): upgrade StructEqTable model and integrate into PDF Extract Kit

- Update StructTableModel to use the latest struct-eqtable library
- Add support for HTML table extraction in PDF Extract Kit
- Improve error handling and model initialization
- Update dependencies in setup.py for struct-eqtable
parent 314f1637
...@@ -38,15 +38,13 @@ except ImportError as e: ...@@ -38,15 +38,13 @@ except ImportError as e:
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
from magic_pdf.model.ppTableModel import ppTableModel from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE: if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
# table_model = StructTableModel(model_path, max_time=max_time, device=_device_) table_model = StructTableModel(model_path, max_time=max_time)
logger.error("StructEqTable is under upgrade, the current version does not support it.")
exit(1)
elif table_model_type == MODEL_NAME.TABLE_MASTER: elif table_model_type == MODEL_NAME.TABLE_MASTER:
config = { config = {
"model_dir": model_path, "model_dir": model_path,
...@@ -393,7 +391,7 @@ class CustomPEKModel: ...@@ -393,7 +391,7 @@ class CustomPEKModel:
elif int(res['category_id']) in [5]: elif int(res['category_id']) in [5]:
table_res_list.append(res) table_res_list.append(res)
if torch.cuda.is_available(): if torch.cuda.is_available() and self.device != 'cpu':
properties = torch.cuda.get_device_properties(self.device) properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 10: if total_memory <= 10:
...@@ -463,7 +461,9 @@ class CustomPEKModel: ...@@ -463,7 +461,9 @@ class CustomPEKModel:
html_code = None html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE: if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad(): with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0] table_result = self.table_model.predict(new_image, "html")
if len(table_result) > 0:
html_code = table_result[0]
else: else:
html_code = self.table_model.img2html(new_image) html_code = self.table_model.img2html(new_image)
...@@ -474,14 +474,17 @@ class CustomPEKModel: ...@@ -474,14 +474,17 @@ class CustomPEKModel:
# 判断是否返回正常 # 判断是否返回正常
if latex_code: if latex_code:
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith( expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
'end{table}')
if expected_ending: if expected_ending:
res["latex"] = latex_code res["latex"] = latex_code
else: else:
logger.warning(f"table recognition processing fails, not found expected LaTeX table end") logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
elif html_code: elif html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
if expected_ending:
res["html"] = html_code res["html"] = html_code
else:
logger.warning(f"table recognition processing fails, not found expected HTML table end")
else: else:
logger.warning(f"table recognition processing fails, not get latex or html return") logger.warning(f"table recognition processing fails, not get latex or html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}") logger.info(f"table time: {round(time.time() - table_start, 2)}")
......
from loguru import logger import torch
from struct_eqtable import build_model
try:
from struct_eqtable.model import StructTable
except ImportError:
logger.error("StructEqTable is under upgrade, the current version does not support it.")
from pypandoc import convert_text
class StructTableModel: class StructTableModel:
def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'): def __init__(self, model_path, max_new_tokens=1024, max_time=60):
# init # init
self.model_path = model_path assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
self.max_new_tokens = max_new_tokens # maximum output tokens length self.model = build_model(
self.max_time = max_time # timeout for processing in seconds model_ckpt=model_path,
if device == 'cuda': max_new_tokens=max_new_tokens,
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda() max_time=max_time,
lmdeploy=False,
flash_attn=False,
batch_size=1,
).cuda()
self.default_format = "html"
def predict(self, images, output_format=None, **kwargs):
if output_format is None:
output_format = self.default_format
else: else:
self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time) if output_format not in ['latex', 'markdown', 'html']:
raise ValueError(f"Output format {output_format} is not supported.")
def image2latex(self, image) -> str: results = self.model(
table_latex = self.model.forward(image) images, output_format=output_format
return table_latex )
def image2html(self, image) -> str: return results
table_latex = self.image2latex(image)
table_html = convert_text(table_latex, 'html', format='latex')
return table_html
...@@ -41,7 +41,7 @@ class ppTableModel(object): ...@@ -41,7 +41,7 @@ class ppTableModel(object):
pred_html = pred_res["html"] pred_html = pred_res["html"]
res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace("</table></body></html>", res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace("</table></body></html>",
"") + "</table></td>\n" "") + "</table></td>\n"
return res return pred_html
def parse_args(self, **kwargs): def parse_args(self, **kwargs):
parser = init_args() parser = init_args()
......
...@@ -43,8 +43,9 @@ if __name__ == '__main__': ...@@ -43,8 +43,9 @@ if __name__ == '__main__':
"paddleocr==2.7.3", # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3 "paddleocr==2.7.3", # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
"paddlepaddle==3.0.0b1;platform_system=='Linux'", # 解决linux的段异常问题 "paddlepaddle==3.0.0b1;platform_system=='Linux'", # 解决linux的段异常问题
"paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'", # windows版本3.0.0b1效率下降,需锁定2.6.1 "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'", # windows版本3.0.0b1效率下降,需锁定2.6.1
"pypandoc", # 表格解析latex转html "struct-eqtable==0.3.2", # 表格解析
"struct-eqtable==0.1.0", # 表格解析 "einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo "doclayout_yolo==0.0.2", # doclayout_yolo
"detectron2" "detectron2"
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment