Commit 102fe277 authored by JesseChen1031's avatar JesseChen1031
Browse files

add support for more document types

parent ce67ccf8
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
from base64 import b64encode from base64 import b64encode
from glob import glob from glob import glob
from io import StringIO from io import StringIO
import tempfile
from typing import Tuple, Union from typing import Tuple, Union
import uvicorn import uvicorn
...@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile ...@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
from magic_pdf.data.read_api import read_local_images, read_local_office
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.operators.models import InferenceResult from magic_pdf.operators.models import InferenceResult
...@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True ...@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True
app = FastAPI() app = FastAPI()
# Supported input types, grouped by file extension. Each entry keeps the
# leading dot so it matches values produced by os.path.splitext(...)[1].
# NOTE(review): downstream temp-file names are built as
# f"temp_file.{file_extension}", which yields a doubled dot (e.g.
# "temp_file..docx") because these extensions already include the dot —
# confirm whether the office/image readers tolerate that.
pdf_extensions = [".pdf"]
office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
# presumably ".jpeg" (and case-insensitive matching) is intentionally
# unsupported — verify against callers.
image_extensions = [".png", ".jpg"]
class MemoryDataWriter(DataWriter): class MemoryDataWriter(DataWriter):
def __init__(self): def __init__(self):
...@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter): ...@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter):
def init_writers( def init_writers(
pdf_path: str = None, file_path: str = None,
pdf_file: UploadFile = None, file: UploadFile = None,
output_path: str = None, output_path: str = None,
output_image_path: str = None, output_image_path: str = None,
) -> Tuple[ ) -> Tuple[
...@@ -68,10 +73,11 @@ def init_writers( ...@@ -68,10 +73,11 @@ def init_writers(
Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
file content file content
""" """
if pdf_path: file_extension:str = None
is_s3_path = pdf_path.startswith("s3://") if file_path:
is_s3_path = file_path.startswith("s3://")
if is_s3_path: if is_s3_path:
bucket = get_bucket_name(pdf_path) bucket = get_bucket_name(file_path)
ak, sk, endpoint = get_s3_config(bucket) ak, sk, endpoint = get_s3_config(bucket)
writer = S3DataWriter( writer = S3DataWriter(
...@@ -84,25 +90,29 @@ def init_writers( ...@@ -84,25 +90,29 @@ def init_writers(
temp_reader = S3DataReader( temp_reader = S3DataReader(
"", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
) )
pdf_bytes = temp_reader.read(pdf_path) file_bytes = temp_reader.read(file_path)
file_extension = os.path.splitext(file_path)[1]
else: else:
writer = FileBasedDataWriter(output_path) writer = FileBasedDataWriter(output_path)
image_writer = FileBasedDataWriter(output_image_path) image_writer = FileBasedDataWriter(output_image_path)
os.makedirs(output_image_path, exist_ok=True) os.makedirs(output_image_path, exist_ok=True)
with open(pdf_path, "rb") as f: with open(file_path, "rb") as f:
pdf_bytes = f.read() file_bytes = f.read()
file_extension = os.path.splitext(file_path)[1]
else: else:
# 处理上传的文件 # 处理上传的文件
pdf_bytes = pdf_file.file.read() file_bytes = file.file.read()
file_extension = os.path.splitext(file.filename)[1]
writer = FileBasedDataWriter(output_path) writer = FileBasedDataWriter(output_path)
image_writer = FileBasedDataWriter(output_image_path) image_writer = FileBasedDataWriter(output_image_path)
os.makedirs(output_image_path, exist_ok=True) os.makedirs(output_image_path, exist_ok=True)
return writer, image_writer, pdf_bytes return writer, image_writer, file_bytes, file_extension
def process_pdf( def process_file(
pdf_bytes: bytes, file_bytes: bytes,
file_extension: str,
parse_method: str, parse_method: str,
image_writer: Union[S3DataWriter, FileBasedDataWriter], image_writer: Union[S3DataWriter, FileBasedDataWriter],
) -> Tuple[InferenceResult, PipeResult]: ) -> Tuple[InferenceResult, PipeResult]:
...@@ -117,7 +127,22 @@ def process_pdf( ...@@ -117,7 +127,22 @@ def process_pdf(
Returns: Returns:
Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
""" """
ds = PymuDocDataset(pdf_bytes)
ds = Union[PymuDocDataset, ImageDataset]
if file_extension in pdf_extensions:
ds = PymuDocDataset(file_bytes)
elif file_extension in office_extensions:
# 需要使用office解析
temp_dir = tempfile.mkdtemp()
with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
f.write(file_bytes)
ds = read_local_office(temp_dir)[0]
elif file_extension in image_extensions:
# 需要使用ocr解析
temp_dir = tempfile.mkdtemp()
with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
f.write(file_bytes)
ds = read_local_images(temp_dir)[0]
infer_result: InferenceResult = None infer_result: InferenceResult = None
pipe_result: PipeResult = None pipe_result: PipeResult = None
...@@ -149,9 +174,9 @@ def encode_image(image_path: str) -> str: ...@@ -149,9 +174,9 @@ def encode_image(image_path: str) -> str:
tags=["projects"], tags=["projects"],
summary="Parse PDF files (supports local files and S3)", summary="Parse PDF files (supports local files and S3)",
) )
async def pdf_parse( async def file_parse(
pdf_file: UploadFile = None, file: UploadFile = None,
pdf_path: str = None, file_path: str = None,
parse_method: str = "auto", parse_method: str = "auto",
is_json_md_dump: bool = False, is_json_md_dump: bool = False,
output_dir: str = "output", output_dir: str = "output",
...@@ -181,31 +206,31 @@ async def pdf_parse( ...@@ -181,31 +206,31 @@ async def pdf_parse(
return_content_list: Whether to return parsed PDF content list. Default to False return_content_list: Whether to return parsed PDF content list. Default to False
""" """
try: try:
if (pdf_file is None and pdf_path is None) or ( if (file is None and file_path is None) or (
pdf_file is not None and pdf_path is not None file is not None and file_path is not None
): ):
return JSONResponse( return JSONResponse(
content={"error": "Must provide either pdf_file or pdf_path"}, content={"error": "Must provide either file or file_path"},
status_code=400, status_code=400,
) )
# Get PDF filename # Get PDF filename
pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split( file_name = os.path.basename(file_path if file_path else file.filename).split(
"." "."
)[0] )[0]
output_path = f"{output_dir}/{pdf_name}" output_path = f"{output_dir}/{file_name}"
output_image_path = f"{output_path}/images" output_image_path = f"{output_path}/images"
# Initialize readers/writers and get PDF content # Initialize readers/writers and get PDF content
writer, image_writer, pdf_bytes = init_writers( writer, image_writer, file_bytes, file_extension = init_writers(
pdf_path=pdf_path, file_path=file_path,
pdf_file=pdf_file, file=file,
output_path=output_path, output_path=output_path,
output_image_path=output_image_path, output_image_path=output_image_path,
) )
# Process PDF # Process PDF
infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer) infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
# Use MemoryDataWriter to get results # Use MemoryDataWriter to get results
content_list_writer = MemoryDataWriter() content_list_writer = MemoryDataWriter()
...@@ -226,23 +251,23 @@ async def pdf_parse( ...@@ -226,23 +251,23 @@ async def pdf_parse(
# If results need to be saved # If results need to be saved
if is_json_md_dump: if is_json_md_dump:
writer.write_string( writer.write_string(
f"{pdf_name}_content_list.json", content_list_writer.get_value() f"{file_name}_content_list.json", content_list_writer.get_value()
) )
writer.write_string(f"{pdf_name}.md", md_content) writer.write_string(f"{file_name}.md", md_content)
writer.write_string( writer.write_string(
f"{pdf_name}_middle.json", middle_json_writer.get_value() f"{file_name}_middle.json", middle_json_writer.get_value()
) )
writer.write_string( writer.write_string(
f"{pdf_name}_model.json", f"{file_name}_model.json",
json.dumps(model_json, indent=4, ensure_ascii=False), json.dumps(model_json, indent=4, ensure_ascii=False),
) )
# Save visualization results # Save visualization results
pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf")) pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf")) pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
pipe_result.draw_line_sort( pipe_result.draw_line_sort(
os.path.join(output_path, f"{pdf_name}_line_sort.pdf") os.path.join(output_path, f"{file_name}_line_sort.pdf")
) )
infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf")) infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
# Build return data # Build return data
data = {} data = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.