Commit 8ffec57d authored by zhougaofeng
Browse files

Update pdf_server.py

parent 7a846eee
...@@ -28,14 +28,13 @@ method = 'auto' ...@@ -28,14 +28,13 @@ method = 'auto'
logger.add("parse.log", rotation="10 MB", level="INFO", logger.add("parse.log", rotation="10 MB", level="INFO",
format="{time} {level} {message}", encoding='utf-8', enqueue=True) format="{time} {level} {message}", encoding='utf-8', enqueue=True)
config_path = None config_path = None
ocr_client = None
ocr_status = None ocr_status = None
custom_model = None custom_model = None
class ocrRequest(BaseModel): class ocrRequest(BaseModel):
path: str path: str
output_dir: str output_dir: str
config_path: str
class ocrResponse(BaseModel): class ocrResponse(BaseModel):
status_code: int status_code: int
...@@ -66,42 +65,44 @@ def parse_args(): ...@@ -66,42 +65,44 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
default='./magic_pdf/config.ini') default='/home/practice/magic_pdf-main/magic_pdf/config.ini')
args = parser.parse_args() args = parser.parse_args()
return args return args
def setup_environment(args):
global config_path,ocr_client,compress_image,ocr_status
def ocr_pdf_serve(args: str):
os.environ["CUDA_VISIBLE_DEVICES"] = args.dcu_id os.environ["CUDA_VISIBLE_DEVICES"] = args.dcu_id
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read(args.config_path) config.read(args.config_path)
# host = config.get('server', 'pdf_host')
# port = int(config.get('server', 'pdf_port'))
vllm_able = config.get('vllm', 'vllm_able')
if vllm_able:
from magic_pdf.dict2md.ocr_vllm_client import PredictClient, compress_image
else:
from magic_pdf.dict2md.ocr_client import PredictClient,compress_image
pdf_server = config.get('server', 'pdf_server').split('://')[1]
host, port = pdf_server.split(':')[0], int(pdf_server.split(':')[1])
global config_path
config_path = args.config_path config_path = args.config_path
pdf_server = config.get('server', 'pdf_server').split('://')[1]
ocr_server = config.get('server', 'ocr_server') ocr_server = config.get('server', 'ocr_server')
vllm_able = config.get('vllm', 'vllm_able')
PredictClient, compress_image = import_ocr_client(vllm_able)
host, port = pdf_server.split(':')[0], int(pdf_server.split(':')[1])
ocr_client = PredictClient(ocr_server) ocr_client = PredictClient(ocr_server)
global ocr_status
ocr_status = ocr_client.check_health() ocr_status = ocr_client.check_health()
return host, port
ocr = True def import_ocr_client(vllm_able: bool):
show_log = False """
model_manager = ModelSingleton() 根据配置动态加载 OCR 客户端模块。
global custom_model """
custom_model = model_manager.get_model(ocr, show_log) if vllm_able:
from magic_pdf.dict2md.ocr_vllm_client import PredictClient, compress_image
else:
from magic_pdf.dict2md.ocr_client import PredictClient, compress_image
return PredictClient, compress_image
def ocr_pdf_serve(args: str):
global custom_model
host, port = setup_environment(args)
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr=True, show_log=False)
uvicorn.run(app, host=host, port=port) uvicorn.run(app, host=host, port=port)
@app.get("/health") @app.get("/health")
...@@ -115,7 +116,6 @@ async def pdf_ocr(request: ocrRequest): ...@@ -115,7 +116,6 @@ async def pdf_ocr(request: ocrRequest):
model_config.__model_mode__ = 'full' model_config.__model_mode__ = 'full'
output_dir = request.output_dir output_dir = request.output_dir
path = request.path path = request.path
#config_path = request.config_path
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
debug_able = False debug_able = False
start_page_id = 0 start_page_id = 0
...@@ -164,22 +164,12 @@ async def pdf_ocr(request: ocrRequest): ...@@ -164,22 +164,12 @@ async def pdf_ocr(request: ocrRequest):
@app.post("/ofd_ocr") @app.post("/ofd_ocr")
async def ofd_ocr(request: ocrRequest): async def ofd_ocr(request: ocrRequest):
try: try:
# 读取配置文件
config = configparser.ConfigParser()
config.read(request.config_path)
url = config.get('server', 'ocr_server')
pdf_server = config.get('server', 'pdf_server')
# 创建客户端
client = PredictClient(url)
# pdf_ocr = ocrPdfClient(pdf_server)
# 确保输出目录存在 # 确保输出目录存在
os.makedirs(request.output_dir, exist_ok=True) os.makedirs(request.output_dir, exist_ok=True)
# 判断 OFD 是否为发票 # 判断 OFD 是否为发票
# logger.info(f'正在判断ofd文件类型') # logger.info(f'正在判断ofd文件类型')
check_res,ofd_imgs,pdfbytes = check_ofd(request.path,client,request.output_dir) check_res,ofd_imgs,pdfbytes = check_ofd(request.path,ocr_client,request.output_dir)
text = '提取图中的文字信息,并以json格式返回' text = '提取图中的文字信息,并以json格式返回'
...@@ -192,7 +182,7 @@ async def ofd_ocr(request: ocrRequest): ...@@ -192,7 +182,7 @@ async def ofd_ocr(request: ocrRequest):
# 如果是发票,进行 OCR 识别 # 如果是发票,进行 OCR 识别
for ofd_img in ofd_imgs: for ofd_img in ofd_imgs:
compress_image(ofd_img) compress_image(ofd_img)
res = client.predict(ofd_img, text) res = ocr_client.predict(ofd_img, text)
res = json_to_txt(res) res = json_to_txt(res)
res = decode_html_entities(res) res = decode_html_entities(res)
ofd_txts += res + '\n' ofd_txts += res + '\n'
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment