server.py 3.21 KB
Newer Older
Hui's avatar
Hui committed
1
import os
2
3
4
5
import uuid
import shutil
import tempfile
import gc
Hui's avatar
Hui committed
6
7
8
import fitz
import torch
import base64
9
import filetype
Hui's avatar
Hui committed
10
import litserve as ls
11
from pathlib import Path
Hui's avatar
Hui committed
12
13
14
15
16
from fastapi import HTTPException


class MinerUAPI(ls.LitAPI):
    """LitServe API that converts an uploaded document to PDF and parses it with MinerU.

    A request carries a base64-encoded file under ``'file'`` and optional
    keyword arguments for ``do_parse`` under ``'kwargs'``. The response
    reports the per-request output directory created under ``output_dir``.
    """

    def __init__(self, output_dir='/tmp'):
        """Remember the root directory for per-request outputs.

        Args:
            output_dir: Root directory. Each request gets a fresh UUID-named
                entry beneath it.
        """
        super().__init__()
        self.output_dir = Path(output_dir)

    def setup(self, device):
        """Pin this worker to its assigned GPU and preload the MinerU models.

        Args:
            device: Device string assigned by LitServe, e.g. ``'cuda:0'``.

        Raises:
            RuntimeError: If more than one CUDA device is still visible after
                pinning, i.e. CUDA was initialized before the environment
                variable could take effect.
        """
        if device.startswith('cuda'):
            _, sep, index = device.partition(':')
            # Only pin when an explicit index is given. A bare 'cuda' would
            # otherwise produce CUDA_VISIBLE_DEVICES='cuda' and hide every GPU.
            if sep:
                os.environ['CUDA_VISIBLE_DEVICES'] = index
                if torch.cuda.device_count() > 1:
                    raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")

        # Imported only after CUDA_VISIBLE_DEVICES is set, presumably so that
        # MinerU's import-time initialization sees just the assigned GPU.
        from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
        from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton

        self.do_parse = do_parse
        self.convert_file_to_pdf = convert_file_to_pdf

        # Warm the model singleton for both flag combinations up front so the
        # first request does not pay the load cost. NOTE(review): the flags are
        # presumably (ocr, show_log); confirm against magic_pdf.
        model_manager = ModelSingleton()
        model_manager.get_model(True, False)
        model_manager.get_model(False, False)
        print(f'Model initialization complete on {device}!')

    def decode_request(self, request):
        """Turn a request body into ``(pdf_bytes, do_parse_options)``.

        Args:
            request: Request body with a base64 ``'file'`` entry and an
                optional ``'kwargs'`` mapping forwarded to ``do_parse``.

        Returns:
            Tuple of the document as PDF bytes and a fresh options dict with
            ``debug_able`` and ``parse_method`` defaulted.

        Raises:
            HTTPException: 400 when ``'file'`` is missing or ``'kwargs'`` is not
                an object. Errors from ``cvt2pdf`` are propagated unchanged.
        """
        if 'file' not in request:
            raise HTTPException(status_code=400, detail="Missing required field 'file'")
        file = self.cvt2pdf(request['file'])

        kwargs = request.get('kwargs') or {}
        if not isinstance(kwargs, dict):
            raise HTTPException(status_code=400, detail="'kwargs' must be an object")
        # Copy so the defaults below do not mutate the caller's request body.
        # NOTE(review): every client-supplied key is forwarded to do_parse as a
        # keyword argument; consider an allowlist.
        opts = dict(kwargs)
        opts.setdefault('debug_able', False)
        opts.setdefault('parse_method', 'auto')
        return file, opts

    def predict(self, inputs):
        """Parse the PDF with MinerU into a fresh per-request directory.

        Args:
            inputs: The ``(pdf_bytes, opts)`` tuple from ``decode_request``.

        Returns:
            ``Path`` of ``output_dir / <uuid>``, the name passed to ``do_parse``.

        Raises:
            HTTPException: 500 wrapping any parse failure. The partial output
                directory is removed first.
        """
        pdf_bytes, opts = inputs
        # Created before the try so the cleanup in `except` can always see it.
        pdf_name = str(uuid.uuid4())
        output_dir = self.output_dir.joinpath(pdf_name)
        try:
            self.do_parse(self.output_dir, pdf_name, pdf_bytes, [], **opts)
            return output_dir
        except Exception as e:
            shutil.rmtree(output_dir, ignore_errors=True)
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Release cached GPU/host memory after every request, success or not.
            self.clean_memory()

    def encode_response(self, response):
        """Wrap the output directory in a JSON-serializable dict.

        The ``Path`` is converted to ``str`` explicitly rather than relying on
        the serializer to understand ``pathlib`` objects.
        """
        return {'output_dir': str(response)}

    def clean_memory(self):
        """Free cached CUDA memory (if CUDA is available) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def cvt2pdf(self, file_base64):
        """Decode a base64 payload and return its contents as PDF bytes.

        PDFs are returned unchanged. JPEG/PNG images are converted with
        PyMuPDF. doc/docx/ppt/pptx files are converted by
        ``convert_file_to_pdf`` in a scratch directory that is always removed.

        Args:
            file_base64: Base64-encoded file content.

        Returns:
            The document as PDF bytes.

        Raises:
            HTTPException: 400 for undecodable input or an unsupported file
                type, 500 when conversion fails.
        """
        try:
            file_bytes = base64.b64decode(file_base64)
        except (TypeError, ValueError) as e:  # binascii.Error subclasses ValueError
            raise HTTPException(status_code=400, detail=f'Invalid base64 file content: {e}') from e

        # Detect the type from the content; None means "unknown".
        file_ext = filetype.guess_extension(file_bytes)
        if file_ext not in ('pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx'):
            raise HTTPException(status_code=400, detail='Unsupported file format')
        if file_ext == 'pdf':
            return file_bytes

        try:
            if file_ext in ('jpg', 'png'):
                with fitz.open(stream=file_bytes, filetype=file_ext) as f:
                    return f.convert_to_pdf()

            # Office formats go through the file system. mkdtemp runs before
            # the inner try so the `finally` never sees an unbound temp_dir.
            temp_dir = Path(tempfile.mkdtemp())
            try:
                # Keep the detected extension on the scratch file as a format
                # hint for the converter; its output is expected as tmpfile.pdf.
                temp_file = temp_dir.joinpath(f'tmpfile.{file_ext}')
                temp_file.write_bytes(file_bytes)
                self.convert_file_to_pdf(temp_file, temp_dir)
                return temp_file.with_suffix('.pdf').read_bytes()
            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == '__main__':
    # One worker per visible CUDA device, each writing results under /tmp.
    # timeout=False disables LitServe's request timeout, since long documents
    # can take a while to parse.
    api = MinerUAPI(output_dir='/tmp')
    server_options = dict(
        accelerator='cuda',
        devices='auto',
        workers_per_device=1,
        timeout=False,
    )
    server = ls.LitServer(api, **server_options)
    server.run(port=8000)