Unverified Commit eae0e6d8 authored by Xiaomeng Zhao, committed by GitHub

Merge branch 'opendatalab:dev' into dev

parents c46d3373 b3faee93
@@ -34,8 +34,6 @@ from magic_pdf.model.model_list import MODEL
 # from magic_pdf.operators.models import InferenceResult
 
-MIN_BATCH_INFERENCE_SIZE = 100
-
 
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -143,17 +141,14 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
 
     images = []
     page_wh_list = []
     for index in range(len(dataset)):
@@ -163,41 +158,16 @@ def doc_analyze(
         images.append(img_dict['img'])
         page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-            for key in sorted(result_history.keys()):
-                results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
 
     model_json = []
     for index in range(len(dataset)):
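In effect, the removed multiprocessing pool is replaced by sequential inference over fixed-size batches, with the batch size read from an environment variable. A minimal sketch of the new behavior, outside the diff (the chunk_pages helper name is illustrative, not part of magic_pdf):

    import os

    def chunk_pages(images, default_batch=100):
        """Mirror the new batching logic: below the threshold everything stays in one batch,
        otherwise pages are split into chunks of MINERU_MIN_BATCH_INFERENCE_SIZE."""
        min_batch = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', default_batch))
        if len(images) >= min_batch:
            return [images[i:i + min_batch] for i in range(0, len(images), min_batch)]
        return [images]

Each chunk is then handed to may_batch_image_analyze in order, so results come back already ordered by page index; for example, setting MINERU_MIN_BATCH_INFERENCE_SIZE=200 would give the model at most 200 page images per call.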
@@ -224,11 +194,8 @@ def batch_doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for dataset in datasets:
@@ -238,40 +205,17 @@ def batch_doc_analyze(
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
 
-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-
-            for key in sorted(result_history.keys()):
-                results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)
 
     infer_results = []
     from magic_pdf.operators.models import InferenceResult
......
@@ -314,7 +314,7 @@ def batch_do_parse(
             dss.append(PymuDocDataset(v, lang=lang))
         else:
             dss.append(v)
-    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
         _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
......
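With the one_shot keyword gone, batch_do_parse simply passes its datasets through, as in the call above. A hedged usage sketch of the updated batch_doc_analyze signature (the file names are illustrative):

    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.model.doc_analyze_by_custom_model import batch_doc_analyze

    pdf_paths = ["a.pdf", "b.pdf"]                        # illustrative inputs
    dss = [PymuDocDataset(open(p, "rb").read()) for p in pdf_paths]
    infer_results = batch_doc_analyze(dss, lang="en")     # no one_shot argument any more
    for ds, infer_result in zip(dss, infer_results):
        model_list = infer_result.get_infer_res()         # per-dataset model output, as consumed by _do_parse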
@@ -3,6 +3,7 @@ import os
 from base64 import b64encode
 from glob import glob
 from io import StringIO
+import tempfile
 from typing import Tuple, Union
 
 import uvicorn
@@ -10,11 +11,12 @@ from fastapi import FastAPI, HTTPException, UploadFile
 from fastapi.responses import JSONResponse
 from loguru import logger
 
+from magic_pdf.data.read_api import read_local_images, read_local_office
 import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter, FileBasedDataWriter
 from magic_pdf.data.data_reader_writer.s3 import S3DataReader, S3DataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
 from magic_pdf.libs.config_reader import get_bucket_name, get_s3_config
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.operators.models import InferenceResult
@@ -24,6 +26,9 @@ model_config.__use_inside_model__ = True
 
 app = FastAPI()
 
+pdf_extensions = [".pdf"]
+office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
+image_extensions = [".png", ".jpg"]
 
 class MemoryDataWriter(DataWriter):
     def __init__(self):
@@ -46,8 +51,8 @@ class MemoryDataWriter(DataWriter):
 
 
 def init_writers(
-    pdf_path: str = None,
-    pdf_file: UploadFile = None,
+    file_path: str = None,
+    file: UploadFile = None,
     output_path: str = None,
     output_image_path: str = None,
 ) -> Tuple[
@@ -59,19 +64,19 @@ def init_writers(
     Initialize writers based on path type
 
     Args:
-        pdf_path: PDF file path (local path or S3 path)
-        pdf_file: Uploaded PDF file object
+        file_path: file path (local path or S3 path)
+        file: Uploaded file object
         output_path: Output directory path
         output_image_path: Image output directory path
 
     Returns:
-        Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
-        file content
+        Tuple[writer, image_writer, file_bytes]: Returns initialized writer tuple and file content
     """
-    if pdf_path:
-        is_s3_path = pdf_path.startswith("s3://")
+    file_extension:str = None
+    if file_path:
+        is_s3_path = file_path.startswith("s3://")
         if is_s3_path:
-            bucket = get_bucket_name(pdf_path)
+            bucket = get_bucket_name(file_path)
             ak, sk, endpoint = get_s3_config(bucket)
             writer = S3DataWriter(
@@ -84,25 +89,29 @@ def init_writers(
             temp_reader = S3DataReader(
                 "", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
             )
-            pdf_bytes = temp_reader.read(pdf_path)
+            file_bytes = temp_reader.read(file_path)
+            file_extension = os.path.splitext(file_path)[1]
         else:
             writer = FileBasedDataWriter(output_path)
             image_writer = FileBasedDataWriter(output_image_path)
             os.makedirs(output_image_path, exist_ok=True)
-            with open(pdf_path, "rb") as f:
-                pdf_bytes = f.read()
+            with open(file_path, "rb") as f:
+                file_bytes = f.read()
+            file_extension = os.path.splitext(file_path)[1]
     else:
         # Handle the uploaded file
-        pdf_bytes = pdf_file.file.read()
+        file_bytes = file.file.read()
+        file_extension = os.path.splitext(file.filename)[1]
         writer = FileBasedDataWriter(output_path)
         image_writer = FileBasedDataWriter(output_image_path)
         os.makedirs(output_image_path, exist_ok=True)
 
-    return writer, image_writer, pdf_bytes
+    return writer, image_writer, file_bytes, file_extension
 
 
-def process_pdf(
-    pdf_bytes: bytes,
+def process_file(
+    file_bytes: bytes,
+    file_extension: str,
     parse_method: str,
     image_writer: Union[S3DataWriter, FileBasedDataWriter],
 ) -> Tuple[InferenceResult, PipeResult]:
@@ -110,14 +119,30 @@ def process_pdf(
     Process PDF file content
 
     Args:
-        pdf_bytes: Binary content of PDF file
+        file_bytes: Binary content of file
+        file_extension: file extension
         parse_method: Parse method ('ocr', 'txt', 'auto')
         image_writer: Image writer
 
     Returns:
         Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
     """
-    ds = PymuDocDataset(pdf_bytes)
+
+    ds = Union[PymuDocDataset, ImageDataset]
+    if file_extension in pdf_extensions:
+        ds = PymuDocDataset(file_bytes)
+    elif file_extension in office_extensions:
+        # Requires office-based parsing
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_office(temp_dir)[0]
+    elif file_extension in image_extensions:
+        # Requires OCR-based parsing
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, f"temp_file.{file_extension}"), "wb") as f:
+            f.write(file_bytes)
+        ds = read_local_images(temp_dir)[0]
+
     infer_result: InferenceResult = None
     pipe_result: PipeResult = None
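The new branch writes non-PDF uploads to a temporary directory so the existing read_api helpers can ingest them. A condensed sketch of the same routing (the dispatch_dataset helper is illustrative; note that os.path.splitext already returns the extension with its leading dot):

    import os
    import tempfile

    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.data.read_api import read_local_images, read_local_office

    # Same values as the module-level lists added above.
    pdf_extensions = [".pdf"]
    office_extensions = [".ppt", ".pptx", ".doc", ".docx"]
    image_extensions = [".png", ".jpg"]

    def dispatch_dataset(file_bytes: bytes, file_extension: str):
        """Route an upload to a dataset type by extension (illustrative helper)."""
        if file_extension in pdf_extensions:
            return PymuDocDataset(file_bytes)
        # Non-PDF payloads go to a temp directory so the read_api helpers can load them.
        temp_dir = tempfile.mkdtemp()
        with open(os.path.join(temp_dir, f"temp_file{file_extension}"), "wb") as f:
            f.write(file_bytes)
        if file_extension in office_extensions:
            return read_local_office(temp_dir)[0]   # office files go through conversion
        return read_local_images(temp_dir)[0]       # .png / .jpg go through OCR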
@@ -145,13 +170,13 @@ def encode_image(image_path: str) -> str:
 
 @app.post(
-    "/pdf_parse",
+    "/file_parse",
     tags=["projects"],
-    summary="Parse PDF files (supports local files and S3)",
+    summary="Parse files (supports local files and S3)",
 )
-async def pdf_parse(
-    pdf_file: UploadFile = None,
-    pdf_path: str = None,
+async def file_parse(
+    file: UploadFile = None,
+    file_path: str = None,
     parse_method: str = "auto",
     is_json_md_dump: bool = False,
     output_dir: str = "output",
@@ -165,10 +190,10 @@ async def pdf_parse(
     to the specified directory.
 
     Args:
-        pdf_file: The PDF file to be parsed. Must not be specified together with
-            `pdf_path`
-        pdf_path: The path to the PDF file to be parsed. Must not be specified together
-            with `pdf_file`
+        file: The PDF file to be parsed. Must not be specified together with
+            `file_path`
+        file_path: The path to the PDF file to be parsed. Must not be specified together
+            with `file`
         parse_method: Parsing method, can be auto, ocr, or txt. Default is auto. If
             results are not satisfactory, try ocr
         is_json_md_dump: Whether to write parsed data to .json and .md files. Default
@@ -181,31 +206,31 @@ async def pdf_parse(
         return_content_list: Whether to return parsed PDF content list. Default to False
     """
     try:
-        if (pdf_file is None and pdf_path is None) or (
-            pdf_file is not None and pdf_path is not None
+        if (file is None and file_path is None) or (
+            file is not None and file_path is not None
         ):
             return JSONResponse(
-                content={"error": "Must provide either pdf_file or pdf_path"},
+                content={"error": "Must provide either file or file_path"},
                 status_code=400,
             )
 
         # Get PDF filename
-        pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
+        file_name = os.path.basename(file_path if file_path else file.filename).split(
             "."
         )[0]
-        output_path = f"{output_dir}/{pdf_name}"
+        output_path = f"{output_dir}/{file_name}"
         output_image_path = f"{output_path}/images"
 
         # Initialize readers/writers and get PDF content
-        writer, image_writer, pdf_bytes = init_writers(
-            pdf_path=pdf_path,
-            pdf_file=pdf_file,
+        writer, image_writer, file_bytes, file_extension = init_writers(
+            file_path=file_path,
+            file=file,
             output_path=output_path,
             output_image_path=output_image_path,
         )
 
         # Process PDF
-        infer_result, pipe_result = process_pdf(pdf_bytes, parse_method, image_writer)
+        infer_result, pipe_result = process_file(file_bytes, file_extension, parse_method, image_writer)
 
         # Use MemoryDataWriter to get results
         content_list_writer = MemoryDataWriter()
@@ -226,23 +251,23 @@ async def pdf_parse(
         # If results need to be saved
         if is_json_md_dump:
             writer.write_string(
-                f"{pdf_name}_content_list.json", content_list_writer.get_value()
+                f"{file_name}_content_list.json", content_list_writer.get_value()
             )
-            writer.write_string(f"{pdf_name}.md", md_content)
+            writer.write_string(f"{file_name}.md", md_content)
             writer.write_string(
-                f"{pdf_name}_middle.json", middle_json_writer.get_value()
+                f"{file_name}_middle.json", middle_json_writer.get_value()
             )
             writer.write_string(
-                f"{pdf_name}_model.json",
+                f"{file_name}_model.json",
                 json.dumps(model_json, indent=4, ensure_ascii=False),
             )
             # Save visualization results
-            pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
-            pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
+            pipe_result.draw_layout(os.path.join(output_path, f"{file_name}_layout.pdf"))
+            pipe_result.draw_span(os.path.join(output_path, f"{file_name}_spans.pdf"))
             pipe_result.draw_line_sort(
-                os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
+                os.path.join(output_path, f"{file_name}_line_sort.pdf")
             )
-            infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
+            infer_result.draw_model(os.path.join(output_path, f"{file_name}_model.pdf"))
 
         # Build return data
         data = {}
......
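A hedged client-side example for the renamed endpoint, assuming the service is served locally by uvicorn on port 8000 (the host, port, and sample file name are assumptions, not part of the diff):

    import requests

    # Upload a local document to the renamed /file_parse endpoint; after this change
    # .pdf, .ppt/.pptx, .doc/.docx, .png and .jpg uploads are accepted.
    with open("demo.docx", "rb") as f:
        response = requests.post(
            "http://127.0.0.1:8000/file_parse",
            files={"file": ("demo.docx", f)},
            params={"parse_method": "auto", "return_content_list": "true"},
        )
    response.raise_for_status()
    print(response.json())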
@@ -183,6 +183,30 @@
       "created_at": "2025-02-26T09:23:25Z",
       "repoId": 765083837,
       "pullRequestNo": 1785
+    },
+    {
+      "name": "rschutski",
+      "id": 179498169,
+      "comment_id": 2705150371,
+      "created_at": "2025-03-06T23:16:30Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1863
+    },
+    {
+      "name": "qbit-",
+      "id": 4794088,
+      "comment_id": 2705914730,
+      "created_at": "2025-03-07T09:09:13Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1863
+    },
+    {
+      "name": "mauryaland",
+      "id": 22381129,
+      "comment_id": 2717322316,
+      "created_at": "2025-03-12T10:03:11Z",
+      "repoId": 765083837,
+      "pullRequestNo": 1906
+    }
   ]
 }
\ No newline at end of file