# StructTableModel.py (960 Bytes)
1
2
import torch
from struct_eqtable import build_model
3
4


5
class StructTableModel:
    """Thin wrapper around a StructEqTable model created by ``build_model``.

    The model is built once in ``__init__`` and moved to the GPU with
    ``.cuda()``; ``predict`` validates the requested output format and then
    calls the model on the given images.
    """

    # Output formats accepted by ``predict``.
    _SUPPORTED_FORMATS = ("latex", "markdown", "html")

    def __init__(self, model_path, max_new_tokens=1024, max_time=60):
        """Build the model from a checkpoint and place it on the GPU.

        Args:
            model_path: Passed to ``build_model`` as ``model_ckpt``.
            max_new_tokens: Forwarded unchanged to ``build_model``.
            max_time: Forwarded unchanged to ``build_model``. The unit is not
                visible here (presumably seconds -- confirm against
                struct_eqtable).

        Raises:
            AssertionError: If CUDA is not available.
        """
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``; exception type and message are kept.
        if not torch.cuda.is_available():
            raise AssertionError("CUDA must be available for StructEqTable model.")

        self.model = build_model(
            model_ckpt=model_path,
            max_new_tokens=max_new_tokens,
            max_time=max_time,
            lmdeploy=False,
            flash_attn=False,
            batch_size=1,
        ).cuda()
        # Format used by ``predict`` when the caller does not pass one.
        self.default_format = "html"

    def predict(self, images, output_format=None, **kwargs):
        """Run the model on ``images`` and return its output unchanged.

        Args:
            images: Passed straight to the model as its first argument; the
                expected type is defined by struct_eqtable, not by this class.
            output_format: One of ``'latex'``, ``'markdown'`` or ``'html'``.
                ``None`` selects ``self.default_format``.
            **kwargs: Accepted for call-site compatibility but not forwarded
                to the model.

        Returns:
            Whatever the underlying model call returns.

        Raises:
            ValueError: If ``output_format`` is given and is not supported.
        """
        if output_format is None:
            # The default is trusted and therefore not re-validated.
            output_format = self.default_format
        elif output_format not in self._SUPPORTED_FORMATS:
            raise ValueError(f"Output format {output_format} is not supported.")

        return self.model(images, output_format=output_format)