Unverified Commit 69e0e00e authored by shniubobo's avatar shniubobo
Browse files

refactor(web_api): Format code

parent c734f4de
......@@ -23,7 +23,6 @@ model_config.__use_inside_model__ = True
app = FastAPI()
class MemoryDataWriter(DataWriter):
def __init__(self):
    """Initialize the writer with an empty in-memory text buffer."""
    # StringIO keeps all written output in memory so it can be fetched
    # later instead of being persisted to disk or S3.
    self.buffer = StringIO()
......@@ -32,7 +31,7 @@ class MemoryDataWriter(DataWriter):
if isinstance(data, str):
self.buffer.write(data)
else:
self.buffer.write(data.decode('utf-8'))
self.buffer.write(data.decode("utf-8"))
def write_string(self, path: str, data: str) -> None:
    """Append *data* to the in-memory buffer.

    The *path* argument is accepted to satisfy the DataWriter
    interface but is ignored by this in-memory implementation —
    all writes go to the single shared buffer.
    """
    self.buffer.write(data)
......@@ -43,12 +42,17 @@ class MemoryDataWriter(DataWriter):
def close(self):
    """Close the underlying StringIO buffer, discarding its contents.

    After calling this, further writes or reads on the buffer raise
    ValueError, so retrieve any needed output before closing.
    """
    self.buffer.close()
def init_writers(
pdf_path: str = None,
pdf_file: UploadFile = None,
output_path: str = None,
output_image_path: str = None,
) -> Tuple[Union[S3DataWriter, FileBasedDataWriter], Union[S3DataWriter, FileBasedDataWriter], bytes]:
) -> Tuple[
Union[S3DataWriter, FileBasedDataWriter],
Union[S3DataWriter, FileBasedDataWriter],
bytes,
]:
"""
Initialize writers based on path type
......@@ -59,24 +63,31 @@ def init_writers(
output_image_path: Image output directory path
Returns:
Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF file content
Tuple[writer, image_writer, pdf_bytes]: Returns initialized writer tuple and PDF
file content
"""
if pdf_path:
is_s3_path = pdf_path.startswith('s3://')
is_s3_path = pdf_path.startswith("s3://")
if is_s3_path:
bucket = get_bucket_name(pdf_path)
ak, sk, endpoint = get_s3_config(bucket)
writer = S3DataWriter(output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
image_writer = S3DataWriter(output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
writer = S3DataWriter(
output_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
)
image_writer = S3DataWriter(
output_image_path, bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
)
# 临时创建reader读取文件内容
temp_reader = S3DataReader("", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint)
temp_reader = S3DataReader(
"", bucket=bucket, ak=ak, sk=sk, endpoint_url=endpoint
)
pdf_bytes = temp_reader.read(pdf_path)
else:
writer = FileBasedDataWriter(output_path)
image_writer = FileBasedDataWriter(output_image_path)
os.makedirs(output_image_path, exist_ok=True)
with open(pdf_path, 'rb') as f:
with open(pdf_path, "rb") as f:
pdf_bytes = f.read()
else:
# 处理上传的文件
......@@ -87,10 +98,11 @@ def init_writers(
return writer, image_writer, pdf_bytes
def process_pdf(
pdf_bytes: bytes,
parse_method: str,
image_writer: Union[S3DataWriter, FileBasedDataWriter]
image_writer: Union[S3DataWriter, FileBasedDataWriter],
) -> Tuple[InferenceResult, PipeResult]:
"""
Process PDF file content
......@@ -104,13 +116,13 @@ def process_pdf(
Tuple[InferenceResult, PipeResult]: Returns inference result and pipeline result
"""
ds = PymuDocDataset(pdf_bytes)
infer_result : InferenceResult = None
pipe_result : PipeResult = None
infer_result: InferenceResult = None
pipe_result: PipeResult = None
if parse_method == 'ocr':
if parse_method == "ocr":
infer_result = ds.apply(doc_analyze, ocr=True)
pipe_result = infer_result.pipe_ocr_mode(image_writer)
elif parse_method == 'txt':
elif parse_method == "txt":
infer_result = ds.apply(doc_analyze, ocr=False)
pipe_result = infer_result.pipe_txt_mode(image_writer)
else: # auto
......@@ -123,23 +135,32 @@ def process_pdf(
return infer_result, pipe_result
@app.post('/pdf_parse', tags=['projects'], summary='Parse PDF files (supports local files and S3)')
@app.post(
"/pdf_parse",
tags=["projects"],
summary="Parse PDF files (supports local files and S3)",
)
async def pdf_parse(
pdf_file: UploadFile = None,
pdf_path: str = None,
parse_method: str = 'auto',
parse_method: str = "auto",
is_json_md_dump: bool = True,
output_dir: str = 'output',
output_dir: str = "output",
return_layout: bool = False,
return_info: bool = False,
return_content_list: bool = False,
):
try:
if pdf_file is None and pdf_path is None:
raise HTTPException(status_code=400, detail="Must provide either pdf_file or pdf_path")
raise HTTPException(
status_code=400, detail="Must provide either pdf_file or pdf_path"
)
# Get PDF filename
pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split('.')[0]
pdf_name = os.path.basename(pdf_path if pdf_path else pdf_file.filename).split(
"."
)[0]
output_path = f"{output_dir}/{pdf_name}"
output_image_path = f"{output_path}/images"
......@@ -148,7 +169,7 @@ async def pdf_parse(
pdf_path=pdf_path,
pdf_file=pdf_file,
output_path=output_path,
output_image_path=output_image_path
output_image_path=output_image_path,
)
# Process PDF
......@@ -172,25 +193,34 @@ async def pdf_parse(
# If results need to be saved
if is_json_md_dump:
writer.write_string(f"{pdf_name}_content_list.json", content_list_writer.get_value())
writer.write_string(
f"{pdf_name}_content_list.json", content_list_writer.get_value()
)
writer.write_string(f"{pdf_name}.md", md_content)
writer.write_string(f"{pdf_name}_middle.json", middle_json_writer.get_value())
writer.write_string(f"{pdf_name}_model.json", json.dumps(model_json, indent=4, ensure_ascii=False))
writer.write_string(
f"{pdf_name}_middle.json", middle_json_writer.get_value()
)
writer.write_string(
f"{pdf_name}_model.json",
json.dumps(model_json, indent=4, ensure_ascii=False),
)
# Save visualization results
pipe_result.draw_layout(os.path.join(output_path, f'{pdf_name}_layout.pdf'))
pipe_result.draw_span(os.path.join(output_path, f'{pdf_name}_spans.pdf'))
pipe_result.draw_line_sort(os.path.join(output_path, f'{pdf_name}_line_sort.pdf'))
infer_result.draw_model(os.path.join(output_path, f'{pdf_name}_model.pdf'))
pipe_result.draw_layout(os.path.join(output_path, f"{pdf_name}_layout.pdf"))
pipe_result.draw_span(os.path.join(output_path, f"{pdf_name}_spans.pdf"))
pipe_result.draw_line_sort(
os.path.join(output_path, f"{pdf_name}_line_sort.pdf")
)
infer_result.draw_model(os.path.join(output_path, f"{pdf_name}_model.pdf"))
# Build return data
data = {}
if return_layout:
data['layout'] = model_json
data["layout"] = model_json
if return_info:
data['info'] = middle_json
data["info"] = middle_json
if return_content_list:
data['content_list'] = content_list
data['md_content'] = md_content # md_content is always returned
data["content_list"] = content_list
data["md_content"] = md_content # md_content is always returned
# Clean up memory writers
content_list_writer.close()
......@@ -201,8 +231,8 @@ async def pdf_parse(
except Exception as e:
logger.exception(e)
return JSONResponse(content={'error': str(e)}, status_code=500)
return JSONResponse(content={"error": str(e)}, status_code=500)
if __name__ == "__main__":
    # Run the FastAPI app directly with uvicorn when executed as a script.
    # NOTE(review): binding to 0.0.0.0 exposes the service on all
    # interfaces — confirm this is intended outside local development.
    uvicorn.run(app, host="0.0.0.0", port=8888)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment