Unverified commit eae0e6d8, authored by Xiaomeng Zhao, committed by GitHub

Merge branch 'opendatalab:dev' into dev

parents c46d3373 b3faee93
@@ -34,8 +34,6 @@ from magic_pdf.model.model_list import MODEL
# from magic_pdf.operators.models import InferenceResult
MIN_BATCH_INFERENCE_SIZE = 100
class ModelSingleton:
_instance = None
_models = {}
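ModelSingleton's methods are cut off by this hunk; the class-level _instance and _models fields indicate a caching singleton that loads each model configuration once and reuses it across doc_analyze calls. The sketch below is a generic illustration of that pattern, not the repository's implementation; the get_model name, the key argument, and the _load_model helper are hypothetical.

import threading

class CachingModelSingleton:
    # Generic caching-singleton sketch; not MinerU's actual class body.
    _instance = None
    _models = {}
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, key):
        # Load each configuration once, then serve the cached object.
        with self._lock:
            if key not in self._models:
                self._models[key] = self._load_model(key)
        return self._models[key]

    def _load_model(self, key):
        # Hypothetical loader; the real class builds OCR/layout/table models here.
        return f"model-for-{key}"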
@@ -143,17 +141,14 @@ def doc_analyze(
layout_model=None,
formula_enable=None,
table_enable=None,
one_shot: bool = True,
):
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(dataset) - 1
)
parallel_count = None
if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
images = []
page_wh_list = []
for index in range(len(dataset)):
@@ -163,41 +158,16 @@ def doc_analyze(
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
if parallel_count is None:
parallel_count = 2 # should check the gpu memory firstly !
# split images into parallel_count batches
if parallel_count > 1:
batch_size = (len(images) + parallel_count - 1) // parallel_count
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
else:
batch_images = [images]
results = []
parallel_count = len(batch_images) # adjust to real parallel count
# using concurrent.futures to analyze
"""
with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
for future in fut.as_completed(futures):
sn, result = future.result()
result_history[sn] = result
for key in sorted(result_history.keys()):
results.extend(result_history[key])
"""
results = []
pool = mp.Pool(processes=parallel_count)
mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
for sn, result in mapped_results:
results.extend(result)
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
batch_size = MIN_BATCH_INFERENCE_SIZE
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
else:
_, results = may_batch_image_analyze(
images,
0,
ocr,
show_log,
lang, layout_model, formula_enable, table_enable)
batch_images = [images]
results = []
for sn, batch_image in enumerate(batch_images):
_, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
results.extend(result)
model_json = []
for index in range(len(dataset)):
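The hunk above removes the one_shot flag and the multiprocessing.Pool fan-out (including the commented-out ProcessPoolExecutor variant) in favour of a simpler scheme: render every page, then, once the page count reaches MIN_BATCH_INFERENCE_SIZE (now read from the MINERU_MIN_BATCH_INFERENCE_SIZE environment variable, default 100), split the images into fixed-size batches and analyze them sequentially in a single process. A minimal sketch of that chunking idiom, with analyze_batch as a hypothetical stand-in for may_batch_image_analyze:

import os

# Default matches the diff: batches of 100 pages unless overridden via the environment.
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))

def analyze_batch(batch_image, sn):
    # Hypothetical stand-in for may_batch_image_analyze: returns (batch index, per-page results).
    return sn, [f'result-{sn}-{i}' for i, _ in enumerate(batch_image)]

def analyze_pages(images):
    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]
    else:
        batch_images = [images]  # small documents stay in a single batch
    results = []
    for sn, batch_image in enumerate(batch_images):
        _, result = analyze_batch(batch_image, sn)
        results.extend(result)  # sequential batches keep page order without sorting
    return results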
@@ -224,11 +194,8 @@ def batch_doc_analyze(
layout_model=None,
formula_enable=None,
table_enable=None,
one_shot: bool = True,
):
parallel_count = None
if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
images = []
page_wh_list = []
for dataset in datasets:
@@ -238,40 +205,17 @@ def batch_doc_analyze(
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
if parallel_count is None:
parallel_count = 2 # should check the gpu memory firstly !
# split images into parallel_count batches
if parallel_count > 1:
batch_size = (len(images) + parallel_count - 1) // parallel_count
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
else:
batch_images = [images]
results = []
parallel_count = len(batch_images) # adjust to real parallel count
# using concurrent.futures to analyze
"""
with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
for future in fut.as_completed(futures):
sn, result = future.result()
result_history[sn] = result
for key in sorted(result_history.keys()):
results.extend(result_history[key])
"""
results = []
pool = mp.Pool(processes=parallel_count)
mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
for sn, result in mapped_results:
results.extend(result)
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
batch_size = MIN_BATCH_INFERENCE_SIZE
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
else:
_, results = may_batch_image_analyze(
images,
0,
ocr,
show_log,
lang, layout_model, formula_enable, table_enable)
batch_images = [images]
results = []
for sn, batch_image in enumerate(batch_images):
_, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
results.extend(result)
infer_results = []
from magic_pdf.operators.models import InferenceResult
@@ -314,7 +314,7 @@ def batch_do_parse(
dss.append(PymuDocDataset(v, lang=lang))
else:
dss.append(v)
infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
for idx, infer_result in enumerate(infer_results):
_do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
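With one_shot gone from the signature, batch_do_parse simply builds the dataset list and hands it to batch_doc_analyze, which decides batching internally from the total page count. A minimal sketch of driving the same call directly; the PDF paths and lang value are placeholders:

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import batch_doc_analyze

# Placeholder inputs for illustration only.
pdf_paths = ['a.pdf', 'b.pdf']
dss = []
for p in pdf_paths:
    with open(p, 'rb') as f:
        dss.append(PymuDocDataset(f.read(), lang='en'))

infer_results = batch_doc_analyze(dss, lang='en', layout_model=None, formula_enable=None, table_enable=None)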
@@ -3,6 +3,7 @@ import os
from base64 import b64encode
from glob import glob
from io import StringIO
import tempfile
from typing import Tuple, Union
import uvicorn
@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from loguru import logger
from magic_pdf.data.read_api import read_local_images, read_local_office
import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.operators.models import InferenceResult
@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True
app = FastAPI()
pdf_extensions = [".pdf"]
office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
image_extensions = [".png", ".jpg"]
class MemoryDataWriter(DataWriter):
def __init__(self):
@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter):
def init_writers(
pdf_path: str = None,
pdf_file: UploadFile = None,
file_path: str = None,
file: UploadFile = None,
output_path: str = None,
output_image_path: str = None,
) -> Tuple[
@@ -59,19 +64,19 @@
Initialize writers based on path type
Args:
pdf_path: PDF file path (local path or S3 path)
pdf_file: Uploaded PDF file object
file_path: File path (local path or S3 path)
file: Uploaded file object
output_path: Output directory path
output_image_path: Image output directory path
Returns:
Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
file content
Tuple[writer, image_writer, file_bytes, file_extension]: Returns initialized writer tuple, file content, and file extension
"""
if pdf_path:
is_s3_path = pdf_path.startswith("s3://")
file_extension: str = None
if file_path:
is_s3_path = file_path.startswith("s3://")
if is_s3_path:
bucket = get_bucket_name(pdf_path)
bucket = get_bucket_name(file_path)
ak, sk, endpoint = get_s3_config(bucket)
writer = S3DataWriter(
@@ -84,25 +89,29 @@
temp_reader = S3DataReader(
"", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
)
pdf_bytes = temp_reader.read(pdf_path)
file_bytes = temp_reader.read(file_path)
file_extension = os.path.splitext(file_path)[1]
else:
writer = FileBasedDataWriter(output_path)
image_writer = FileBasedDataWriter(output_image_path)
os.makedirs(output_image_path, exist_ok=True)
with open(pdf_path, "rb") as f:
pdf_bytes = f.read()
with open(file_path, "rb") as f:
file_bytes = f.read()
file_extension = os.path.splitext(file_path)[1]
else:
# Handle the uploaded file
pdf_bytes = pdf_file.file.read()
file_bytes = file.file.read()
file_extension = os.path.splitext(file.filename)[1]
writer = FileBasedDataWriter(output_path)
image_writer = FileBasedDataWriter(output_image_path)
os.makedirs(output_image_path, exist_ok=True)
return writer, image_writer, pdf_bytes
return writer, image_writer, file_bytes, file_extension
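init_writers now returns a four-tuple whose last element is the file extension (with its leading dot, as produced by os.path.splitext); process_file uses it to pick the dataset type. A minimal usage sketch for the local-path branch of the function above, with placeholder paths:

# Placeholder paths for illustration only.
writer, image_writer, file_bytes, file_extension = init_writers(
    file_path='demo/report.docx',
    output_path='output/report',
    output_image_path='output/report/images',
)
assert file_extension == '.docx'  # splitext keeps the leading dot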
def process_pdf(
pdf_bytes: bytes,
def process_file(
file_bytes: bytes,
file_extension: str,
parse_method: str,
image_writer: Union[S3DataWriter, FileBasedDataWriter],
) -> Tuple[InferenceResult, PipeResult]:
@@ -110,14 +119,30 @@ def process_pdf(
Process file content (PDF, office document, or image)
Args:
pdf_bytes: Binary content of PDF file
file_bytes: Binary content of file
file_extension: File extension, e.g. '.pdf', '.docx', or '.png'
parse_method: Parse method ('ocr', 'txt', 'auto')
image_writer: Image writer
Returns:
Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
"""
ds = PymuDocDataset(pdf_bytes)
ds: Union[PymuDocDataset, ImageDataset]
if file_extension in pdf_extensions:
ds = PymuDocDataset(file_bytes)
elif file_extension in office_extensions:
# Needs office-document parsing
temp_dir = tempfile.mkdtemp()
with open(os.path.join(temp_dir, f"temp_file{file_extension}"), "wb") as f:
f.write(file_bytes)
ds = read_local_office(temp_dir)[0]
elif file_extension in image_extensions:
# Needs OCR-based parsing
temp_dir = tempfile.mkdtemp()
with open(os.path.join(temp_dir, f"temp_file{file_extension}"), "wb") as f:
f.write(file_bytes)
ds = read_local_images(temp_dir)[0]
infer_result: InferenceResult = None
pipe_result: PipeResult = None
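process_file dispatches on the extension: PDF bytes go straight into PymuDocDataset, while office documents and images are first written to a temporary directory and loaded back through read_local_office or read_local_images. A minimal sketch of calling the function above for a local PDF; the file name and image directory are placeholders:

from magic_pdf.data.data_reader_writer import FileBasedDataWriter

# Placeholder paths for illustration only.
with open('sample.pdf', 'rb') as f:
    file_bytes = f.read()
image_writer = FileBasedDataWriter('output/sample/images')
infer_result, pipe_result = process_file(file_bytes, '.pdf', 'auto', image_writer)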
@@ -145,13 +170,13 @@ def encode_image(image_path: str) -> str:
@app.post(
"/pdf_parse",
"/file_parse",
tags=["projects"],
summary="Parse PDF files (supports local files and S3)",
summary="Parse files (supports local files and S3)",
)
async def pdf_parse(
pdf_file: UploadFile = None,
pdf_path: str = None,
async def file_parse(
file: UploadFile = None,
file_path: str = None,
parse_method: str = "auto",
is_json_md_dump: bool = False,
output_dir: str = "output",
@@ -165,10 +190,10 @@ async def pdf_parse(
to the specified directory.
Args:
pdf_file: The PDF file to be parsed. Must not be specified together with
`pdf_path`
pdf_path: The path to the PDF file to be parsed. Must not be specified together
with `pdf_file`
file: The file to be parsed. Must not be specified together with
`file_path`
file_path: The path to the file to be parsed. Must not be specified together
with `file`
parse_method: Parsing method, can be auto, ocr, or txt. Default is auto. If
results are not satisfactory, try ocr
is_json_md_dump: Whether to write parsed data to .json and .md files. Default
@@ -181,31 +206,31 @@ async def pdf_parse(
return_content_list: Whether to return the parsed content list. Defaults to False
"""
try:
if (pdf_file is None and pdf_path is None) or (
pdf_file is not None and pdf_path is not None
if (file is None and file_path is None) or (
file is not None and file_path is not None
):
return JSONResponse(
content={"error": "Must provide either pdf_file or pdf_path"},
content={"error": "Must provide either file or file_path"},
status_code=400,
)
# Get PDF filename
pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
file_name = os.path.basename(file_path if file_path else file.filename).split(
"."
)[0]
output_path = f"{output_dir}/{pdf_name}"
output_path = f"{output_dir}/{file_name}"
output_image_path = f"{output_path}/images"
# Initialize readers/writers and get PDF content
writer, image_writer, pdf_bytes = init_writers(
pdf_path=pdf_path,
pdf_file=pdf_file,
writer, image_writer, file_bytes, file_extension = init_writers(
file_path=file_path,
file=file,
output_path=output_path,
output_image_path=output_image_path,
)
# Process PDF
infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
# Use MemoryDataWriter to get results
content_list_writer = MemoryDataWriter()
@@ -226,23 +251,23 @@ async def pdf_parse(
# If results need to be saved
if is_json_md_dump:
writer.write_string(
f"{pdf_name}_content_list.json", content_list_writer.get_value()
f"{file_name}_content_list.json", content_list_writer.get_value()
)
writer.write_string(f"{pdf_name}.md", md_content)
writer.write_string(f"{file_name}.md", md_content)
writer.write_string(
f"{pdf_name}_middle.json", middle_json_writer.get_value()
f"{file_name}_middle.json", middle_json_writer.get_value()
)
writer.write_string(
f"{pdf_name}_model.json",
f"{file_name}_model.json",
json.dumps(model_json, indent=4, ensure_ascii=False),
)
# Save visualization results
pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
pipe_result.draw_line_sort(
os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
os.path.join(output_path, f"{file_name}_line_sort.pdf")
)
infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
# Build return data
data = {}
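The renamed /file_parse endpoint accepts either an uploaded file or a file_path, and now handles office documents and images as well as PDFs. A hedged client example using the requests library, assuming the app is served by uvicorn on localhost:8000 and that the plain-typed arguments are sent as query parameters (FastAPI's default when they are not declared as form fields):

import requests

# Assumed local deployment; adjust host, port, and file name to your setup.
with open('report.docx', 'rb') as f:
    resp = requests.post(
        'http://127.0.0.1:8000/file_parse',
        files={'file': ('report.docx', f)},
        params={'parse_method': 'auto', 'is_json_md_dump': True, 'output_dir': 'output'},
    )
print(resp.status_code)
print(resp.json())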
@@ -183,6 +183,30 @@
"created_at": "2025-02-26T09:23:25Z",
"repoId": 765083837,
"pullRequestNo": 1785
},
{
"name": "rschutski",
"id": 179498169,
"comment_id": 2705150371,
"created_at": "2025-03-06T23:16:30Z",
"repoId": 765083837,
"pullRequestNo": 1863
},
{
"name": "qbit-",
"id": 4794088,
"comment_id": 2705914730,
"created_at": "2025-03-07T09:09:13Z",
"repoId": 765083837,
"pullRequestNo": 1863
},
{
"name": "mauryaland",
"id": 22381129,
"comment_id": 2717322316,
"created_at": "2025-03-12T10:03:11Z",
"repoId": 765083837,
"pullRequestNo": 1906
}
]
}
\ No newline at end of file