import json
import os
from io import StringIO
from typing import Tuple, Union

import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from loguru import logger

import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.operators.models import InferenceResult
from magic_pdf.operators.pipes import PipeResult

# Use the models bundled inside magic_pdf rather than an external model service.
model_config.__use_inside_model__ = True

# FastAPI application instance; routes are registered on it below.
app = FastAPI()

class MemoryDataWriter(DataWriter):
    """In-memory ``DataWriter`` that accumulates written data into a string buffer.

    Used to capture pipeline dump output (markdown / JSON) without touching
    disk or S3. The ``path`` argument of the write methods is ignored because
    there is no backing storage.
    """

    def __init__(self):
        self.buffer = StringIO()

    def write(self, path: str, data: bytes) -> None:
        """Append *data* to the buffer, decoding ``bytes`` as UTF-8 when needed.

        Args:
            path: Ignored (kept for DataWriter interface compatibility).
            data: Payload; ``str`` is written as-is, ``bytes`` is UTF-8 decoded.
        """
        if isinstance(data, str):
            self.buffer.write(data)
        else:
            self.buffer.write(data.decode('utf-8'))

    def write_string(self, path: str, data: str) -> None:
        """Append an already-decoded string to the buffer (``path`` is ignored)."""
        self.buffer.write(data)

    def get_value(self) -> str:
        """Return everything written so far as a single string."""
        return self.buffer.getvalue()

    def close(self):
        """Release the underlying buffer; the writer is unusable afterwards."""
        self.buffer.close()

    # Context-manager support so callers can guarantee cleanup with ``with``;
    # purely additive, existing call sites are unaffected.
    def __enter__(self) -> 'MemoryDataWriter':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()
def _make_local_writers(
    output_path: str, output_image_path: str
) -> Tuple[FileBasedDataWriter, FileBasedDataWriter]:
    """Create file-system writers and ensure the image directory exists."""
    os.makedirs(output_image_path, exist_ok=True)
    return FileBasedDataWriter(output_path), FileBasedDataWriter(output_image_path)


def init_writers(
    pdf_path: str = None,
    pdf_file: UploadFile = None,
    output_path: str = None,
    output_image_path: str = None,
) -> Tuple[Union[S3DataWriter, FileBasedDataWriter], Union[S3DataWriter, FileBasedDataWriter], bytes]:
    """
    Initialize writers based on path type.

    Exactly one of ``pdf_path`` / ``pdf_file`` is expected; ``pdf_path``
    takes precedence when both are given.

    Args:
        pdf_path: PDF file path (local path or S3 path starting with 's3://')
        pdf_file: Uploaded PDF file object
        output_path: Output directory path
        output_image_path: Image output directory path

    Returns:
        Tuple[writer, image_writer, pdf_bytes]: Initialized writer pair and
        the raw PDF file content.
    """
    if pdf_path:
        if pdf_path.startswith('s3://'):
            bucket = get_bucket_name(pdf_path)
            ak, sk, endpoint = get_s3_config(bucket)

            writer = S3DataWriter(output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
            image_writer = S3DataWriter(output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
            # Temporary reader used only to pull the PDF bytes out of S3.
            temp_reader = S3DataReader("", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
            pdf_bytes = temp_reader.read(pdf_path)
        else:
            writer, image_writer = _make_local_writers(output_path, output_image_path)
            with open(pdf_path, 'rb') as f:
                pdf_bytes = f.read()
    else:
        # Handle an uploaded file: bytes come from the multipart upload stream.
        pdf_bytes = pdf_file.file.read()
        writer, image_writer = _make_local_writers(output_path, output_image_path)

    return writer, image_writer, pdf_bytes

def process_pdf(
    pdf_bytes: bytes,
    parse_method: str,
    image_writer: Union[S3DataWriter, FileBasedDataWriter]
) -> Tuple[InferenceResult, PipeResult]:
    """
    Process PDF file content.

    Args:
        pdf_bytes: Binary content of PDF file
        parse_method: Parse method ('ocr', 'txt', 'auto'); any other value
            falls through to 'auto', matching previous behavior.
        image_writer: Image writer

    Returns:
        Tuple[InferenceResult, PipeResult]: Inference result and pipeline result
    """
    ds = PymuDocDataset(pdf_bytes)

    # Resolve the parse mode to a single OCR flag instead of duplicating the
    # analyze/pipe calls across three branches.
    if parse_method == 'ocr':
        use_ocr = True
    elif parse_method == 'txt':
        use_ocr = False
    else:  # auto: let the dataset classify itself
        use_ocr = ds.classify() == SupportedPdfParseMethod.OCR

    infer_result: InferenceResult = ds.apply(doc_analyze, ocr=use_ocr)
    if use_ocr:
        pipe_result: PipeResult = infer_result.pipe_ocr_mode(image_writer)
    else:
        pipe_result: PipeResult = infer_result.pipe_txt_mode(image_writer)

    return infer_result, pipe_result
@app.post('/pdf_parse', tags=['projects'], summary='Parse PDF files (supports local files and S3)')
async def pdf_parse(
    pdf_file: UploadFile = None,
    pdf_path: str = None,
    parse_method: str = 'auto',
    is_json_md_dump: bool = True,
    output_dir: str = 'output',
    return_layout: bool = False,
    return_info: bool = False,
    return_content_list: bool = False,
):
    """Parse a PDF (uploaded file or local/S3 path) and return extracted content.

    Args:
        pdf_file: Uploaded PDF (multipart); mutually exclusive with pdf_path.
        pdf_path: Local or 's3://' path to the PDF.
        parse_method: 'ocr', 'txt', or 'auto'.
        is_json_md_dump: Persist JSON/markdown/visualization artifacts via the writer.
        output_dir: Root directory (or S3 prefix) for outputs.
        return_layout / return_info / return_content_list: Opt-in extra payload
            sections; markdown content is always returned.

    Returns:
        JSONResponse with 'md_content' plus any requested sections, or a
        500 response carrying the error string.
    """
    try:
        if pdf_file is None and pdf_path is None:
            raise HTTPException(status_code=400, detail="Must provide either pdf_file or pdf_path")

        # Derive the output folder name from the filename without its extension.
        # splitext (rather than split('.')) keeps dotted names like 'v1.2_report' intact.
        base_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename)
        pdf_name = os.path.splitext(base_name)[0]
        output_path = f"{output_dir}/{pdf_name}"
        output_image_path = f"{output_path}/images"

        # Initialize readers/writers and get PDF content
        writer, image_writer, pdf_bytes = init_writers(
            pdf_path=pdf_path,
            pdf_file=pdf_file,
            output_path=output_path,
            output_image_path=output_image_path
        )

        # Process PDF
        infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)

        # Capture pipeline dumps in memory
        content_list_writer = MemoryDataWriter()
        md_content_writer = MemoryDataWriter()
        middle_json_writer = MemoryDataWriter()
        try:
            pipe_result.dump_content_list(content_list_writer, "", "images")
            pipe_result.dump_md(md_content_writer, "", "images")
            pipe_result.dump_middle_json(middle_json_writer, "")

            content_list = json.loads(content_list_writer.get_value())
            md_content = md_content_writer.get_value()
            middle_json = json.loads(middle_json_writer.get_value())
            model_json = infer_result.get_infer_res()

            # Persist results when requested
            if is_json_md_dump:
                writer.write_string(f"{pdf_name}_content_list.json", content_list_writer.get_value())
                writer.write_string(f"{pdf_name}.md", md_content)
                writer.write_string(f"{pdf_name}_middle.json", middle_json_writer.get_value())
                writer.write_string(f"{pdf_name}_model.json", json.dumps(model_json, indent=4, ensure_ascii=False))
                # Visualization PDFs are always written to the local output path.
                pipe_result.draw_layout(os.path.join(output_path, f'{pdf_name}_layout.pdf'))
                pipe_result.draw_span(os.path.join(output_path, f'{pdf_name}_spans.pdf'))
                pipe_result.draw_line_sort(os.path.join(output_path, f'{pdf_name}_line_sort.pdf'))
                infer_result.draw_model(os.path.join(output_path, f'{pdf_name}_model.pdf'))
        finally:
            # Release the in-memory buffers even if dumping/serialization fails.
            content_list_writer.close()
            md_content_writer.close()
            middle_json_writer.close()

        # Build return data
        data = {}
        if return_layout:
            data['layout'] = model_json
        if return_info:
            data['info'] = middle_json
        if return_content_list:
            data['content_list'] = content_list
        data['md_content'] = md_content  # md_content is always returned

        return JSONResponse(data, status_code=200)

    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) so FastAPI can
        # report them, instead of collapsing them into a generic 500 below.
        raise
    except Exception as e:
        logger.exception(e)
        return JSONResponse(content={'error': str(e)}, status_code=500)

if __name__ == '__main__':
    # Entry point for running the API directly: listen on all interfaces, port 8888.
    uvicorn.run(app, host='0.0.0.0', port=8888)