# StructTableModel.py
import re

3
4
import torch
from struct_eqtable import build_model
5
6


7
class StructTableModel:
    """Wrapper around a model created by ``struct_eqtable.build_model``.

    The wrapped model is moved to CUDA at construction time. ``predict``
    validates the requested output format, calls the model and, for HTML
    output, minifies every returned string.
    """

    # Output formats accepted by ``predict``.
    SUPPORTED_FORMATS = ("latex", "markdown", "html")

    def __init__(self, model_path, max_new_tokens=1024, max_time=60):
        """Build the underlying model and place it on the GPU.

        Args:
            model_path: Forwarded to ``build_model`` as ``model_ckpt``.
            max_new_tokens: Forwarded to ``build_model`` unchanged.
            max_time: Forwarded to ``build_model`` unchanged.
                NOTE(review): presumably a generation time limit in
                seconds -- confirm against struct_eqtable.

        Raises:
            RuntimeError: If ``torch.cuda.is_available()`` is false.
        """
        # Raise explicitly instead of using ``assert``: assertions are removed
        # when Python runs with -O, which would let construction proceed and
        # fail later, less clearly, inside ``.cuda()``.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA must be available for StructEqTable model.")

        self.model = build_model(
            model_ckpt=model_path,
            max_new_tokens=max_new_tokens,
            max_time=max_time,
            lmdeploy=False,
            flash_attn=False,
            batch_size=1,
        ).cuda()
        # Format used by ``predict`` when the caller does not pass one.
        self.default_format = "html"

    def predict(self, images, output_format=None, **kwargs):
        """Run the model on ``images`` and return its results.

        Args:
            images: Passed to the underlying model as-is.
            output_format: One of ``'latex'``, ``'markdown'`` or ``'html'``.
                ``None`` selects ``self.default_format``.
            **kwargs: Accepted so existing call sites keep working; currently
                not forwarded to the model.

        Returns:
            Whatever the model returns. For ``'html'`` the results are
            iterated and a list of minified HTML strings is returned.

        Raises:
            ValueError: If an explicit ``output_format`` is not supported.
        """
        if output_format is None:
            output_format = self.default_format
        elif output_format not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Output format {output_format} is not supported.")

        results = self.model(images, output_format=output_format)

        if output_format == "html":
            results = [self.minify_html(html) for html in results]

        return results

    def minify_html(self, html: str) -> str:
        """Return ``html`` with insignificant whitespace removed.

        Whitespace runs are collapsed to one space, whitespace on either side
        of ``<`` and ``>`` is dropped, and the result is stripped.
        """
        # Collapse every run of whitespace (including newlines) to one space.
        html = re.sub(r'\s+', ' ', html)
        # Drop whitespace on both sides of each '>'.
        html = re.sub(r'\s*>\s*', '>', html)
        # Drop whitespace on both sides of each '<'.
        html = re.sub(r'\s*<\s*', '<', html)
        return html.strip()