Commit 667d2c0d authored by Hui
Browse files

Adjust projects/multi_gpu/server.py for magic_pdf-1.0.1

parent e41d7be3
......@@ -31,7 +31,7 @@ python server.py
### 2. 启动客户端
以下代码展示了客户端的使用方式,可根据需求修改配置:
```python
files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件
files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 pdf、jpg/jpeg、png、doc、docx、ppt、pptx 文件
n_jobs = np.clip(len(files), 1, 8) # 设置并发线程数,此处最大为 8,可根据自身修改
results = Parallel(n_jobs, prefer='threads', verbose=10)(
delayed(do_parse)(p) for p in files
......
......@@ -31,7 +31,7 @@ def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs):
if __name__ == '__main__':
files = ['small_ocr.pdf']
files = ['demo/small_ocr.pdf']
n_jobs = np.clip(len(files), 1, 8)
results = Parallel(n_jobs, prefer='threads', verbose=10)(
delayed(do_parse)(p) for p in files
......
import os
import uuid
import shutil
import tempfile
import gc
import fitz
import torch
import base64
import filetype
import litserve as ls
from uuid import uuid4
from pathlib import Path
from fastapi import HTTPException
from filetype import guess_extension
from magic_pdf.tools.common import do_parse
from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
class MinerUAPI(ls.LitAPI):
def __init__(self, output_dir='/tmp'):
self.output_dir = output_dir
self.output_dir = Path(output_dir)
def setup(self, device):
if device.startswith('cuda'):
......@@ -27,7 +31,7 @@ class MinerUAPI(ls.LitAPI):
def decode_request(self, request):
file = request['file']
file = self.to_pdf(file)
file = self.cvt2pdf(file)
opts = request.get('kwargs', {})
opts.setdefault('debug_able', False)
opts.setdefault('parse_method', 'auto')
......@@ -35,9 +39,12 @@ class MinerUAPI(ls.LitAPI):
def predict(self, inputs):
    """Parse one decoded request and return the directory holding its results.

    Args:
        inputs: Two-item sequence; inputs[0] is passed to do_parse as the
            document payload and inputs[1] is a dict of extra keyword
            options forwarded to do_parse.

    Returns:
        Path of the per-request result directory
        (self.output_dir / <random uuid>).

    Raises:
        HTTPException: status 500 carrying the original error text when
            parsing fails for any reason.
    """
    # Bind these before the try block so the except handler can always
    # reference output_dir when cleaning up.
    pdf_name = str(uuid.uuid4())
    output_dir = self.output_dir.joinpath(pdf_name)
    try:
        # do_parse receives the root directory plus the unique name;
        # presumably it writes under <root>/<pdf_name> -- verify against
        # the installed magic_pdf version.
        do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1])
        return output_dir
    except Exception as e:
        # Do not leave partial results behind for a failed request.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release cached memory after every request, success or failure.
        self.clean_memory()
......@@ -46,21 +53,34 @@ class MinerUAPI(ls.LitAPI):
return {'output_dir': response}
def clean_memory(self):
    """Release cached CUDA memory (when CUDA is available) and run the GC."""
    if torch.cuda.is_available():
        # Hand cached allocator blocks back to the driver and reclaim
        # memory held by CUDA IPC handles.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    # Always collect Python garbage, including on CPU-only hosts.
    gc.collect()
def cvt2pdf(self, file_base64):
    """Decode a base64 payload and return its content as PDF bytes.

    The file type is sniffed from the decoded bytes: PDFs are returned
    unchanged, jpg/png images are converted with PyMuPDF, and
    doc/docx/ppt/pptx files are converted through convert_file_to_pdf
    using a scratch directory.

    Args:
        file_base64: Base64-encoded file content.

    Returns:
        The document as PDF bytes.

    Raises:
        HTTPException: status 500 when the payload cannot be decoded, the
            format is unsupported, or the conversion fails.
    """
    # Created lazily: only office formats need a scratch directory. Keeping
    # the name bound up front lets the finally clause run safely on every
    # path, including a failing mkdtemp().
    temp_dir = None
    try:
        file_bytes = base64.b64decode(file_base64)
        file_ext = filetype.guess_extension(file_bytes)
        if file_ext == 'pdf':
            return file_bytes
        if file_ext in ('jpg', 'png'):
            with fitz.open(stream=file_bytes, filetype=file_ext) as f:
                return f.convert_to_pdf()
        if file_ext in ('doc', 'docx', 'ppt', 'pptx'):
            temp_dir = Path(tempfile.mkdtemp())
            temp_file = temp_dir.joinpath('tmpfile')
            temp_file.write_bytes(file_bytes)
            # The result is read back from 'tmpfile.pdf' in the scratch
            # directory; presumably the converter names its output after
            # the input stem -- verify against magic_pdf.tools.cli.
            convert_file_to_pdf(temp_file, temp_dir)
            return temp_file.with_suffix('.pdf').read_bytes()
        raise Exception('Unsupported file format')
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment