Commit 11f23843 authored by myhloli
Browse files

feat(table): upgrade StructEqTable model and integrate into PDF Extract Kit

- Update StructTableModel to use the latest struct-eqtable library
- Add support for HTML table extraction in PDF Extract Kit
- Improve error handling and model initialization
- Update dependencies in setup.py for struct-eqtable
parent 314f1637
......@@ -38,15 +38,13 @@ except ImportError as e:
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
# table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
logger.error("StructEqTable is under upgrade, the current version does not support it.")
exit(1)
table_model = StructTableModel(model_path, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
config = {
"model_dir": model_path,
......@@ -393,7 +391,7 @@ class CustomPEKModel:
elif int(res['category_id']) in [5]:
table_res_list.append(res)
if torch.cuda.is_available():
if torch.cuda.is_available() and self.device != 'cpu':
properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 10:
......@@ -463,7 +461,9 @@ class CustomPEKModel:
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
latex_code = self.table_model.image2latex(new_image)[0]
table_result = self.table_model.predict(new_image, "html")
if len(table_result) > 0:
html_code = table_result[0]
else:
html_code = self.table_model.img2html(new_image)
......@@ -474,14 +474,17 @@ class CustomPEKModel:
# 判断是否返回正常
if latex_code:
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
'end{table}')
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
if expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
elif html_code:
res["html"] = html_code
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
if expected_ending:
res["html"] = html_code
else:
logger.warning(f"table recognition processing fails, not found expected HTML table end")
else:
logger.warning(f"table recognition processing fails, not get latex or html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}")
......
from loguru import logger
try:
from struct_eqtable.model import StructTable
except ImportError:
logger.error("StructEqTable is under upgrade, the current version does not support it.")
from pypandoc import convert_text
import torch
from struct_eqtable import build_model
class StructTableModel:
    """Thin wrapper around the struct-eqtable table recognition model.

    The underlying model is created with ``build_model`` and moved to the
    GPU, so constructing this class requires a CUDA device.
    """

    # Output formats accepted by ``predict``.
    SUPPORTED_FORMATS = ('latex', 'markdown', 'html')

    def __init__(self, model_path, max_new_tokens=1024, max_time=60):
        """Build the struct-eqtable model and place it on the GPU.

        Args:
            model_path: Checkpoint location, forwarded as ``model_ckpt``.
            max_new_tokens: Maximum output tokens length.
            max_time: Timeout for processing in seconds.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        # Raise explicitly instead of using ``assert``: assertions are
        # stripped under ``python -O`` and the check would silently vanish.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA must be available for StructEqTable model.")
        self.model = build_model(
            model_ckpt=model_path,
            max_new_tokens=max_new_tokens,
            max_time=max_time,
            lmdeploy=False,
            flash_attn=False,
            batch_size=1,
        ).cuda()
        # Format used by ``predict`` when the caller does not pass one.
        self.default_format = "html"

    def predict(self, images, output_format=None, **kwargs):
        """Run table recognition on ``images``.

        Args:
            images: Input forwarded unchanged to the underlying model.
            output_format: One of 'latex', 'markdown' or 'html'. ``None``
                selects ``self.default_format``.
            **kwargs: Accepted for call compatibility; currently unused.

        Returns:
            Whatever the underlying model returns for ``images``.
            NOTE(review): presumably a sequence with one entry per image --
            confirm against the struct-eqtable API.

        Raises:
            ValueError: If an explicit ``output_format`` is not supported.
        """
        if output_format is None:
            output_format = self.default_format
        elif output_format not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Output format {output_format} is not supported.")

        return self.model(images, output_format=output_format)
......@@ -41,7 +41,7 @@ class ppTableModel(object):
pred_html = pred_res["html"]
res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace("</table></body></html>",
"") + "</table></td>\n"
return res
return pred_html
def parse_args(self, **kwargs):
parser = init_args()
......
......@@ -43,8 +43,9 @@ if __name__ == '__main__':
"paddleocr==2.7.3", # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
"paddlepaddle==3.0.0b1;platform_system=='Linux'", # 解决linux的段异常问题
"paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'", # windows版本3.0.0b1效率下降,需锁定2.6.1
"pypandoc", # 表格解析latex转html
"struct-eqtable==0.1.0", # 表格解析
"struct-eqtable==0.3.2", # 表格解析
"einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo
"detectron2"
],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment