Commit 8ffec57d authored by zhougaofeng

Update pdf_server.py

parent 7a846eee
@@ -28,14 +28,13 @@ method = 'auto'
logger.add("parse.log", rotation="10 MB", level="INFO",
           format="{time} {level} {message}", encoding='utf-8', enqueue=True)
config_path = None
ocr_client = None
ocr_status = None
custom_model = None

class ocrRequest(BaseModel):
    path: str
    output_dir: str
    config_path: str

class ocrResponse(BaseModel):
    status_code: int
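The three new module-level globals (ocr_client, ocr_status, custom_model) are initialised once at startup by setup_environment below and shared by the endpoints. For reference, a minimal sketch of a body that FastAPI would validate against ocrRequest; the field values are placeholders, and ocrResponse may declare more fields than the status_code visible in this hunk:

    # Sketch only: placeholder values, not part of the commit.
    req = ocrRequest(
        path="/data/sample.pdf",        # document to parse (placeholder)
        output_dir="/data/output",      # where results are written (placeholder)
        config_path="/home/practice/magic_pdf-main/magic_pdf/config.ini",
    )
    print(req.dict())  # use req.model_dump() instead on pydantic v2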
@@ -66,42 +65,44 @@ def parse_args():
    )
    parser.add_argument(
        '--config_path',
        default='./magic_pdf/config.ini')
        default='/home/practice/magic_pdf-main/magic_pdf/config.ini')
    args = parser.parse_args()
    return args
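setup_environment below reads three settings from the file passed via --config_path: server.pdf_server, server.ocr_server, and vllm.vllm_able. Note that configparser always returns strings, so any non-empty vllm_able value (even "false") is truthy in `if vllm_able:`. A minimal sketch of the assumed config shape, with placeholder hosts and ports; the real config.ini may contain more sections and keys:

    import configparser

    # Assumed shape only; values are placeholders, not from this commit.
    SAMPLE = """
    [server]
    pdf_server = http://0.0.0.0:8001
    ocr_server = http://127.0.0.1:9000

    [vllm]
    vllm_able =
    """

    cfg = configparser.ConfigParser()
    cfg.read_string(SAMPLE)
    pdf_server = cfg.get('server', 'pdf_server').split('://')[1]
    host, port = pdf_server.split(':')[0], int(pdf_server.split(':')[1])
    print(host, port)                           # -> 0.0.0.0 8001
    print(bool(cfg.get('vllm', 'vllm_able')))   # empty string -> False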
def ocr_pdf_serve(args: str):
def setup_environment(args):
    global config_path, ocr_client, compress_image, ocr_status
    os.environ["CUDA_VISIBLE_DEVICES"] = args.dcu_id
    config = configparser.ConfigParser()
    config.read(args.config_path)
    # host = config.get('server', 'pdf_host')
    # port = int(config.get('server', 'pdf_port'))
    vllm_able = config.get('vllm', 'vllm_able')
    if vllm_able:
        from magic_pdf.dict2md.ocr_vllm_client import PredictClient, compress_image
    else:
        from magic_pdf.dict2md.ocr_client import PredictClient, compress_image
    pdf_server = config.get('server', 'pdf_server').split('://')[1]
    host, port = pdf_server.split(':')[0], int(pdf_server.split(':')[1])
    global config_path
    config_path = args.config_path
    pdf_server = config.get('server', 'pdf_server').split('://')[1]
    ocr_server = config.get('server', 'ocr_server')
    vllm_able = config.get('vllm', 'vllm_able')
    PredictClient, compress_image = import_ocr_client(vllm_able)
    host, port = pdf_server.split(':')[0], int(pdf_server.split(':')[1])
    ocr_client = PredictClient(ocr_server)
    global ocr_status
    ocr_status = ocr_client.check_health()
    return host, port
    ocr = True
    show_log = False
    model_manager = ModelSingleton()
    global custom_model
    custom_model = model_manager.get_model(ocr, show_log)

def import_ocr_client(vllm_able: bool):
    """
    Dynamically load the OCR client module based on the configuration.
    """
    if vllm_able:
        from magic_pdf.dict2md.ocr_vllm_client import PredictClient, compress_image
    else:
        from magic_pdf.dict2md.ocr_client import PredictClient, compress_image
    return PredictClient, compress_image

def ocr_pdf_serve(args: str):
    global custom_model
    host, port = setup_environment(args)
    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(ocr=True, show_log=False)
    uvicorn.run(app, host=host, port=port)
@app.get("/health")
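The diff does not show the module's entry point; presumably parse_args and the refactored ocr_pdf_serve are wired together roughly as follows. This is a sketch of the assumed wiring, not code from this commit:

    # Hypothetical entry point; not shown in this diff.
    if __name__ == '__main__':
        args = parse_args()
        # setup_environment() loads the config, builds the shared OCR client,
        # and records its health in ocr_status before uvicorn starts serving.
        ocr_pdf_serve(args)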
@@ -115,7 +116,6 @@ async def pdf_ocr(request: ocrRequest):
    model_config.__model_mode__ = 'full'
    output_dir = request.output_dir
    path = request.path
    # config_path = request.config_path
    os.makedirs(output_dir, exist_ok=True)
    debug_able = False
    start_page_id = 0
@@ -164,22 +164,12 @@ async def pdf_ocr(request: ocrRequest):
@app.post("/ofd_ocr")
async def ofd_ocr(request: ocrRequest):
    try:
        # Read the config file
        config = configparser.ConfigParser()
        config.read(request.config_path)
        url = config.get('server', 'ocr_server')
        pdf_server = config.get('server', 'pdf_server')
        # Create the client
        client = PredictClient(url)
        # pdf_ocr = ocrPdfClient(pdf_server)
        # Make sure the output directory exists
        os.makedirs(request.output_dir, exist_ok=True)
        # Check whether the OFD file is an invoice
        # logger.info(f'Checking the OFD file type')
        check_res, ofd_imgs, pdfbytes = check_ofd(request.path, client, request.output_dir)
        check_res, ofd_imgs, pdfbytes = check_ofd(request.path, ocr_client, request.output_dir)
        # OCR prompt (Chinese): extract the text from the image and return it as JSON
        text = '提取图中的文字信息,并以json格式返回'
@@ -192,7 +182,7 @@ async def ofd_ocr(request: ocrRequest):
        # If it is an invoice, run OCR on each page image
        for ofd_img in ofd_imgs:
            compress_image(ofd_img)
            res = client.predict(ofd_img, text)
            res = ocr_client.predict(ofd_img, text)
            res = json_to_txt(res)
            res = decode_html_entities(res)
            ofd_txts += res + '\n'
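Since ofd_ocr now reuses the shared ocr_client built at startup instead of constructing a PredictClient per request, a caller only needs the fields defined by ocrRequest. A hedged example of exercising the routes confirmed in this diff (/health and /ofd_ocr) with the requests library; the host, port, file paths, and response shapes are assumptions, not shown here:

    import requests

    BASE = "http://127.0.0.1:8001"   # assumed host/port from the [server] pdf_server entry

    print(requests.get(f"{BASE}/health", timeout=10).status_code)

    payload = {
        "path": "/data/invoice.ofd",     # placeholder OFD file
        "output_dir": "/data/output",    # placeholder output directory
        "config_path": "/home/practice/magic_pdf-main/magic_pdf/config.ini",
    }
    resp = requests.post(f"{BASE}/ofd_ocr", json=payload, timeout=600)
    print(resp.status_code, resp.text)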