import base64
import gc
import os
import shutil
import tempfile
import uuid
from pathlib import Path

import filetype
import fitz
import litserve as ls
import torch
from fastapi import HTTPException

from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.tools.cli import do_parse, convert_file_to_pdf


class MinerUAPI(ls.LitAPI):
    """LitServe API that converts an uploaded document to PDF and parses it with MinerU.

    A request is a dict carrying a base64-encoded file under ``'file'`` and,
    optionally, extra ``do_parse`` keyword arguments under ``'kwargs'``. Each
    successful request is parsed by ``do_parse`` under a fresh, uuid-named
    sub-directory of ``output_dir``, and the response reports that directory.
    """

    # Detected file extensions (as reported by ``filetype``) that cvt2pdf accepts,
    # grouped by how they are turned into PDF bytes. 'pdf' is passed through as-is.
    _IMAGE_EXTS = ('jpg', 'png')
    _OFFICE_EXTS = ('doc', 'docx', 'ppt', 'pptx')

    def __init__(self, output_dir='/tmp'):
        """Create the API.

        Args:
            output_dir: Root directory under which per-request parse results are written.
        """
        super().__init__()
        self.output_dir = Path(output_dir)

    def setup(self, device):
        """Pin this worker to its GPU (if any) and pre-load the MinerU models.

        Args:
            device: Device string for this worker, e.g. ``'cpu'`` or ``'cuda:0'``.

        Raises:
            RuntimeError: If more than one CUDA device is still visible after pinning.
        """
        if device.startswith('cuda'):
            _, _, index = device.partition(':')
            # Only pin when an explicit index is present: a bare 'cuda' would otherwise
            # produce CUDA_VISIBLE_DEVICES='cuda', which is not a valid device list.
            if index:
                os.environ['CUDA_VISIBLE_DEVICES'] = index
                # If CUDA was initialised before the variable was set, the restriction
                # has no effect and several devices remain visible.
                if torch.cuda.device_count() > 1:
                    raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")

        model_manager = ModelSingleton()
        # Warm up both model variants at start-up. NOTE(review): the positional flags
        # are presumably (ocr, show_log) — confirm against ModelSingleton.get_model.
        model_manager.get_model(True, False)
        model_manager.get_model(False, False)
        print(f'Model initialization complete on {device}!')

    def decode_request(self, request):
        """Convert a request payload into ``(pdf_bytes, do_parse_kwargs)``.

        Defaults ``debug_able=False`` and ``parse_method='auto'`` are filled in
        unless the client supplied them.

        Raises:
            HTTPException: 400 if ``'file'`` is missing; see ``cvt2pdf`` for others.
        """
        if 'file' not in request:
            raise HTTPException(status_code=400, detail="Missing required field 'file'")
        pdf_bytes = self.cvt2pdf(request['file'])

        # Copy so the defaults below never mutate the caller's payload; tolerate null.
        opts = dict(request.get('kwargs') or {})
        opts.setdefault('debug_able', False)
        opts.setdefault('parse_method', 'auto')
        return pdf_bytes, opts

    def predict(self, inputs):
        """Run ``do_parse`` on the decoded PDF.

        Args:
            inputs: The ``(pdf_bytes, opts)`` tuple produced by ``decode_request``.

        Returns:
            ``self.output_dir / <uuid>``, the directory named after the file name
            handed to ``do_parse``.

        Raises:
            HTTPException: 500 if parsing fails; the partial output directory is removed.
        """
        pdf_bytes, opts = inputs
        pdf_name = str(uuid.uuid4())
        output_dir = self.output_dir.joinpath(pdf_name)
        try:
            do_parse(self.output_dir, pdf_name, pdf_bytes, [], **opts)
            return output_dir
        except Exception as e:
            shutil.rmtree(output_dir, ignore_errors=True)
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Release memory after every request, successful or not.
            self.clean_memory()

    def encode_response(self, response):
        """Wrap the output directory in a JSON-serialisable dict."""
        return {'output_dir': str(response)}

    def clean_memory(self):
        """Free cached CUDA memory (when CUDA is available) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def cvt2pdf(self, file_base64):
        """Decode a base64 upload and return it as PDF bytes.

        PDFs are returned unchanged. JPEG/PNG images are converted with PyMuPDF.
        doc/docx/ppt/pptx files are converted with ``convert_file_to_pdf``.

        Args:
            file_base64: Base64-encoded file contents.

        Raises:
            HTTPException: 400 for undecodable base64 or an unsupported file type;
                500 if the conversion itself fails.
        """
        try:
            file_bytes = base64.b64decode(file_base64)
        except (TypeError, ValueError) as e:  # binascii.Error is a ValueError
            raise HTTPException(status_code=400, detail=f'Invalid base64 payload: {e}') from e

        file_ext = filetype.guess_extension(file_bytes)
        if file_ext == 'pdf':
            return file_bytes
        if file_ext not in self._IMAGE_EXTS and file_ext not in self._OFFICE_EXTS:
            raise HTTPException(status_code=400, detail='Unsupported file format')

        try:
            if file_ext in self._IMAGE_EXTS:
                with fitz.open(stream=file_bytes, filetype=file_ext) as doc:
                    return doc.convert_to_pdf()
            return self._office_to_pdf(file_bytes, file_ext)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    @staticmethod
    def _office_to_pdf(file_bytes, file_ext):
        """Convert Office document bytes to PDF bytes inside a throw-away directory."""
        # Created before the try so the cleanup in finally never sees an unbound name.
        temp_dir = Path(tempfile.mkdtemp())
        try:
            # Keep the detected extension on the scratch file; presumably this helps the
            # converter pick the right input format. Its output is read back as tmpfile.pdf.
            temp_file = temp_dir.joinpath(f'tmpfile.{file_ext}')
            temp_file.write_bytes(file_bytes)
            convert_file_to_pdf(temp_file, temp_dir)
            return temp_file.with_suffix('.pdf').read_bytes()
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)


if __name__ == '__main__':
    # Launch the MinerU parsing service. devices='auto' lets LitServe choose the GPUs,
    # and workers_per_device=1 starts one worker per chosen device.
    api = MinerUAPI(output_dir='/tmp')
    server_options = dict(
        accelerator='cuda',
        devices='auto',
        workers_per_device=1,
        # Disable LitServe's per-request timeout so long document parses are not cut off.
        timeout=False,
    )
    server = ls.LitServer(api, **server_options)
    server.run(port=8000)