Commit 667d2c0d authored by Hui
Browse files

Adjust projects/multi_gpu/server.py for magic_pdf-1.0.1

parent e41d7be3
...@@ -31,7 +31,7 @@ python server.py ...@@ -31,7 +31,7 @@ python server.py
### 2. 启动客户端 ### 2. 启动客户端
以下代码展示了客户端的使用方式,可根据需求修改配置: 以下代码展示了客户端的使用方式,可根据需求修改配置:
```python ```python
files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件 files = ['demo/small_ocr.pdf'] # 替换为文件路径,支持 pdf、jpg/jpeg、png、doc、docx、ppt、pptx 文件
n_jobs = np.clip(len(files), 1, 8) # 设置并发线程数,此处最大为 8,可根据自身修改 n_jobs = np.clip(len(files), 1, 8) # 设置并发线程数,此处最大为 8,可根据自身修改
results = Parallel(n_jobs, prefer='threads', verbose=10)( results = Parallel(n_jobs, prefer='threads', verbose=10)(
delayed(do_parse)(p) for p in files delayed(do_parse)(p) for p in files
......
...@@ -31,7 +31,7 @@ def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs): ...@@ -31,7 +31,7 @@ def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs):
if __name__ == '__main__': if __name__ == '__main__':
files = ['small_ocr.pdf'] files = ['demo/small_ocr.pdf']
n_jobs = np.clip(len(files), 1, 8) n_jobs = np.clip(len(files), 1, 8)
results = Parallel(n_jobs, prefer='threads', verbose=10)( results = Parallel(n_jobs, prefer='threads', verbose=10)(
delayed(do_parse)(p) for p in files delayed(do_parse)(p) for p in files
......
import os import os
import uuid
import shutil
import tempfile
import gc
import fitz import fitz
import torch import torch
import base64 import base64
import filetype
import litserve as ls import litserve as ls
from uuid import uuid4 from pathlib import Path
from fastapi import HTTPException from fastapi import HTTPException
from filetype import guess_extension from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
from magic_pdf.tools.common import do_parse
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
class MinerUAPI(ls.LitAPI): class MinerUAPI(ls.LitAPI):
def __init__(self, output_dir='/tmp'): def __init__(self, output_dir='/tmp'):
self.output_dir = output_dir self.output_dir = Path(output_dir)
def setup(self, device): def setup(self, device):
if device.startswith('cuda'): if device.startswith('cuda'):
...@@ -27,7 +31,7 @@ class MinerUAPI(ls.LitAPI): ...@@ -27,7 +31,7 @@ class MinerUAPI(ls.LitAPI):
def decode_request(self, request): def decode_request(self, request):
file = request['file'] file = request['file']
file = self.to_pdf(file) file = self.cvt2pdf(file)
opts = request.get('kwargs', {}) opts = request.get('kwargs', {})
opts.setdefault('debug_able', False) opts.setdefault('debug_able', False)
opts.setdefault('parse_method', 'auto') opts.setdefault('parse_method', 'auto')
...@@ -35,9 +39,12 @@ class MinerUAPI(ls.LitAPI): ...@@ -35,9 +39,12 @@ class MinerUAPI(ls.LitAPI):
def predict(self, inputs): def predict(self, inputs):
try: try:
do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1]) pdf_name = str(uuid.uuid4())
return pdf_name output_dir = self.output_dir.joinpath(pdf_name)
do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1])
return output_dir
except Exception as e: except Exception as e:
shutil.rmtree(output_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally: finally:
self.clean_memory() self.clean_memory()
...@@ -46,21 +53,34 @@ class MinerUAPI(ls.LitAPI): ...@@ -46,21 +53,34 @@ class MinerUAPI(ls.LitAPI):
return {'output_dir': response} return {'output_dir': response}
def clean_memory(self): def clean_memory(self):
import gc
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
gc.collect() gc.collect()
def to_pdf(self, file_base64): def cvt2pdf(self, file_base64):
try: try:
temp_dir = Path(tempfile.mkdtemp())
temp_file = temp_dir.joinpath('tmpfile')
file_bytes = base64.b64decode(file_base64) file_bytes = base64.b64decode(file_base64)
file_ext = guess_extension(file_bytes) file_ext = filetype.guess_extension(file_bytes)
with fitz.open(stream=file_bytes, filetype=file_ext) as f:
if f.is_pdf: return f.tobytes() if file_ext in ['pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx']:
return f.convert_to_pdf() if file_ext == 'pdf':
return file_bytes
elif file_ext in ['jpg', 'png']:
with fitz.open(stream=file_bytes, filetype=file_ext) as f:
return f.convert_to_pdf()
else:
temp_file.write_bytes(file_bytes)
convert_file_to_pdf(temp_file, temp_dir)
return temp_file.with_suffix('.pdf').read_bytes()
else:
raise Exception('Unsupported file format')
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment