"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "6371f7af27c17b28a879d8af677362acba59bf51"
Unverified Commit bed386f7 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #964 from myhloli/dev

refactor(model): rename and restructure model modules
parents 8ddbe8bb c064379c
import numpy as np
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
class RapidTableModel(object):
def __init__(self):
self.table_model = RapidTable()
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
def predict(self, image):
ocr_result, _ = self.ocr_engine(np.asarray(image))
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
\ No newline at end of file
import re
import torch import torch
from struct_eqtable import build_model from struct_eqtable import build_model
from magic_pdf.model.sub_modules.table.table_utils import minify_html
class StructTableModel: class StructTableModel:
def __init__(self, model_path, max_new_tokens=1024, max_time=60): def __init__(self, model_path, max_new_tokens=1024, max_time=60):
...@@ -31,15 +31,7 @@ class StructTableModel: ...@@ -31,15 +31,7 @@ class StructTableModel:
) )
if output_format == "html": if output_format == "html":
results = [self.minify_html(html) for html in results] results = [minify_html(html) for html in results]
return results return results
def minify_html(self, html):
# 移除多余的空白字符
html = re.sub(r'\s+', ' ', html)
# 移除行尾的空白字符
html = re.sub(r'\s*>\s*', '>', html)
# 移除标签前的空白字符
html = re.sub(r'\s*<\s*', '<', html)
return html.strip()
\ No newline at end of file
import re
def minify_html(html):
# 移除多余的空白字符
html = re.sub(r'\s+', ' ', html)
# 移除行尾的空白字符
html = re.sub(r'\s*>\s*', '>', html)
# 移除标签前的空白字符
html = re.sub(r'\s*<\s*', '<', html)
return html.strip()
\ No newline at end of file
...@@ -7,7 +7,7 @@ from PIL import Image ...@@ -7,7 +7,7 @@ from PIL import Image
import numpy as np import numpy as np
class ppTableModel(object): class TableMasterPaddleModel(object):
""" """
This class is responsible for converting image of table into HTML format using a pre-trained model. This class is responsible for converting image of table into HTML format using a pre-trained model.
......
...@@ -164,8 +164,8 @@ class ModelSingleton: ...@@ -164,8 +164,8 @@ class ModelSingleton:
def do_predict(boxes: List[List[int]], model) -> List[int]: def do_predict(boxes: List[List[int]], model) -> List[int]:
from magic_pdf.model.v3.helpers import (boxes2inputs, parse_logits, from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits,
prepare_inputs) prepare_inputs)
inputs = boxes2inputs(boxes) inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model) inputs = prepare_inputs(inputs, model)
...@@ -206,7 +206,7 @@ def cal_block_index(fix_blocks, sorted_bboxes): ...@@ -206,7 +206,7 @@ def cal_block_index(fix_blocks, sorted_bboxes):
del block['real_lines'] del block['real_lines']
import numpy as np import numpy as np
from magic_pdf.model.v3.xycut import recursive_xy_cut from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut
random_boxes = np.array(block_bboxes) random_boxes = np.array(block_bboxes)
np.random.shuffle(random_boxes) np.random.shuffle(random_boxes)
......
...@@ -49,6 +49,7 @@ if __name__ == '__main__': ...@@ -49,6 +49,7 @@ if __name__ == '__main__':
"doclayout_yolo==0.0.2", # doclayout_yolo "doclayout_yolo==0.0.2", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle "rapidocr-paddle", # rapidocr-paddle
"rapid_table", # rapid_table "rapid_table", # rapid_table
"PyYAML", # yaml
"detectron2" "detectron2"
], ],
}, },
......
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
from PIL import Image from PIL import Image
from lxml import etree from lxml import etree
from magic_pdf.model.ppTableModel import ppTableModel from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
class TestppTableModel(unittest.TestCase): class TestppTableModel(unittest.TestCase):
...@@ -11,7 +11,7 @@ class TestppTableModel(unittest.TestCase): ...@@ -11,7 +11,7 @@ class TestppTableModel(unittest.TestCase):
# 修改table模型路径 # 修改table模型路径
config = {"device": "cuda", config = {"device": "cuda",
"model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"} "model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"}
table_model = ppTableModel(config) table_model = TableMasterPaddleModel(config)
res = table_model.img2html(img) res = table_model.img2html(img)
# 验证生成的 HTML 是否符合预期 # 验证生成的 HTML 是否符合预期
parser = etree.HTMLParser() parser = etree.HTMLParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment