Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhougaofeng
magic_pdf
Commits
8ffec57d
Commit
8ffec57d
authored
Nov 25, 2024
by
zhougaofeng
Browse files
Update pdf_server.py
parent
7a846eee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
35 deletions
+25
-35
magic_pdf/tools/pdf_server.py
magic_pdf/tools/pdf_server.py
+25
-35
No files found.
magic_pdf/tools/pdf_server.py
View file @
8ffec57d
...
...
@@ -28,14 +28,13 @@ method = 'auto'
logger
.
add
(
"parse.log"
,
rotation
=
"10 MB"
,
level
=
"INFO"
,
format
=
"{time} {level} {message}"
,
encoding
=
'utf-8'
,
enqueue
=
True
)
config_path
=
None
ocr_client
=
None
ocr_status
=
None
custom_model
=
None
class
ocrRequest
(
BaseModel
):
path
:
str
output_dir
:
str
config_path
:
str
class
ocrResponse
(
BaseModel
):
status_code
:
int
...
...
@@ -66,42 +65,44 @@ def parse_args():
)
parser
.
add_argument
(
'--config_path'
,
default
=
'
.
/magic_pdf/config.ini'
)
default
=
'
/home/practice/magic_pdf-main
/magic_pdf/config.ini'
)
args
=
parser
.
parse_args
()
return
args
def
ocr_pdf_serve
(
args
:
str
):
def
setup_environment
(
args
):
global
config_path
,
ocr_client
,
compress_image
,
ocr_status
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
args
.
dcu_id
config
=
configparser
.
ConfigParser
()
config
.
read
(
args
.
config_path
)
# host = config.get('server', 'pdf_host')
# port = int(config.get('server', 'pdf_port'))
vllm_able
=
config
.
get
(
'vllm'
,
'vllm_able'
)
if
vllm_able
:
from
magic_pdf.dict2md.ocr_vllm_client
import
PredictClient
,
compress_image
else
:
from
magic_pdf.dict2md.ocr_client
import
PredictClient
,
compress_image
pdf_server
=
config
.
get
(
'server'
,
'pdf_server'
).
split
(
'://'
)[
1
]
host
,
port
=
pdf_server
.
split
(
':'
)[
0
],
int
(
pdf_server
.
split
(
':'
)[
1
])
global
config_path
config_path
=
args
.
config_path
pdf_server
=
config
.
get
(
'server'
,
'pdf_server'
).
split
(
'://'
)[
1
]
ocr_server
=
config
.
get
(
'server'
,
'ocr_server'
)
vllm_able
=
config
.
get
(
'vllm'
,
'vllm_able'
)
PredictClient
,
compress_image
=
import_ocr_client
(
vllm_able
)
host
,
port
=
pdf_server
.
split
(
':'
)[
0
],
int
(
pdf_server
.
split
(
':'
)[
1
])
ocr_client
=
PredictClient
(
ocr_server
)
global
ocr_status
ocr_status
=
ocr_client
.
check_health
()
return
host
,
port
ocr
=
True
show_log
=
False
model_manager
=
ModelSingleton
()
global
custom_model
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
)
def import_ocr_client(vllm_able: bool):
    """Load the OCR client implementation selected by the configuration.

    Args:
        vllm_able: Whether the vLLM-backed client should be used. Besides
            real booleans, the raw text read from a config file is accepted
            and interpreted with the same spellings as
            ``configparser.ConfigParser.getboolean`` ("1"/"yes"/"true"/"on"
            and "0"/"no"/"false"/"off", case-insensitive). An empty string
            counts as disabled, as it did before.

    Returns:
        The ``(PredictClient, compress_image)`` pair imported from
        ``magic_pdf.dict2md.ocr_vllm_client`` when enabled, otherwise from
        ``magic_pdf.dict2md.ocr_client``.

    Raises:
        ValueError: If a string flag is not a recognised boolean spelling.
    """
    if isinstance(vllm_able, str):
        # ConfigParser.get() returns text, and any non-empty string --
        # including "False" -- is truthy, so the flag is parsed explicitly
        # instead of relying on truthiness.
        flag = vllm_able.strip().lower()
        if flag in ('1', 'yes', 'true', 'on'):
            vllm_able = True
        elif flag in ('', '0', 'no', 'false', 'off'):
            vllm_able = False
        else:
            raise ValueError(f'Not a boolean: {vllm_able!r}')

    # The imports stay inside the branches so only the selected client
    # module is loaded.
    if vllm_able:
        from magic_pdf.dict2md.ocr_vllm_client import PredictClient, compress_image
    else:
        from magic_pdf.dict2md.ocr_client import PredictClient, compress_image
    return PredictClient, compress_image
def ocr_pdf_serve(args: str):
    """Prepare the environment, load the model, and serve ``app`` with uvicorn.

    Args:
        args: Passed unchanged to ``setup_environment``.
            NOTE(review): annotated ``str``, but ``setup_environment`` reads
            ``args.dcu_id`` and ``args.config_path``, so this looks like the
            parsed command-line arguments -- confirm and fix the annotation.
    """
    global custom_model

    # setup_environment hands back the (host, port) pair to bind to.
    bind_host, bind_port = setup_environment(args)

    # Store the loaded model in the module-level ``custom_model``.
    custom_model = ModelSingleton().get_model(ocr=True, show_log=False)

    uvicorn.run(app, host=bind_host, port=bind_port)
@
app
.
get
(
"/health"
)
...
...
@@ -115,7 +116,6 @@ async def pdf_ocr(request: ocrRequest):
model_config
.
__model_mode__
=
'full'
output_dir
=
request
.
output_dir
path
=
request
.
path
#config_path = request.config_path
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
debug_able
=
False
start_page_id
=
0
...
...
@@ -164,22 +164,12 @@ async def pdf_ocr(request: ocrRequest):
@
app
.
post
(
"/ofd_ocr"
)
async
def
ofd_ocr
(
request
:
ocrRequest
):
try
:
# 读取配置文件
config
=
configparser
.
ConfigParser
()
config
.
read
(
request
.
config_path
)
url
=
config
.
get
(
'server'
,
'ocr_server'
)
pdf_server
=
config
.
get
(
'server'
,
'pdf_server'
)
# 创建客户端
client
=
PredictClient
(
url
)
# pdf_ocr = ocrPdfClient(pdf_server)
# 确保输出目录存在
os
.
makedirs
(
request
.
output_dir
,
exist_ok
=
True
)
# 判断 OFD 是否为发票
# logger.info(f'正在判断ofd文件类型')
check_res
,
ofd_imgs
,
pdfbytes
=
check_ofd
(
request
.
path
,
client
,
request
.
output_dir
)
check_res
,
ofd_imgs
,
pdfbytes
=
check_ofd
(
request
.
path
,
ocr_
client
,
request
.
output_dir
)
text
=
'提取图中的文字信息,并以json格式返回'
...
...
@@ -192,7 +182,7 @@ async def ofd_ocr(request: ocrRequest):
# 如果是发票,进行 OCR 识别
for
ofd_img
in
ofd_imgs
:
compress_image
(
ofd_img
)
res
=
client
.
predict
(
ofd_img
,
text
)
res
=
ocr_
client
.
predict
(
ofd_img
,
text
)
res
=
json_to_txt
(
res
)
res
=
decode_html_entities
(
res
)
ofd_txts
+=
res
+
'
\n
'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment