zhougaofeng / magic_pdf · Commits

Commit 922eeb6e
authored Oct 24, 2024 by zhougaofeng

Update magic_pdf/dict2md/ocr_vllm_client.py, magic_pdf/dict2md/ocr_vllm_server.py files

parent 0d0edaf5
Showing 2 changed files with 473 additions and 0 deletions.
magic_pdf/dict2md/ocr_vllm_client.py   +183  -0
magic_pdf/dict2md/ocr_vllm_server.py   +290  -0
magic_pdf/dict2md/ocr_vllm_client.py (new file, 0 → 100644)
import os
import json
import requests
from loguru import logger
import argparse
import time
from PIL import Image
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--url', default='http://0.0.0.0:6088', help='The URL of the server')
    parser.add_argument('--image_path', default='/path/to/your/image.png', help='Path to the image file')
    parser.add_argument('--text', default="描述你在图片中看到的内容",  # i.e. "Describe what you see in the image"
                        help='Text input for the model')
    args = parser.parse_args()
    return args
def parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line.strip() != ""]  # drop empty lines
    count = 0
    parsed_lines = []
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # opening a code block
                parsed_lines.append(f'<pre><code class="language-{items[-1]}">')
            else:
                # closing a code block
                parsed_lines.append("</code></pre>")
        else:
            if i > 0 and count % 2 == 1:
                # escape special characters inside a code block
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            # join lines with a space
            if parsed_lines:
                parsed_lines[-1] += " " + line
            else:
                parsed_lines.append(line)
    text = "".join(parsed_lines)
    return text
def unparse_text(parsed_text):
    in_code_block = False
    lines = parsed_text.split("\n")
    unparsed_lines = []
    for line in lines:
        if "<pre><code" in line:
            in_code_block = True
            # strip the opening tag
            line = line.split(">", 1)[1]
        elif "</code></pre>" in line:
            in_code_block = False
            # strip the closing tag
            line = line.rsplit("<", 1)[0]
        # reverse the HTML entities
        line = line.replace("&lt;", "<")
        line = line.replace("&gt;", ">")
        line = line.replace("&nbsp;", " ")
        line = line.replace("&ast;", "*")
        line = line.replace("&lowbar;", "_")
        line = line.replace("&#45;", "-")
        line = line.replace("&#46;", ".")
        line = line.replace("&#33;", "!")
        line = line.replace("&#40;", "(")
        line = line.replace("&#41;", ")")
        line = line.replace("&#36;", "$")
        # inside a code block, undo the backslash escaping of backticks
        if in_code_block:
            line = line.replace(r"\`", "`")
        unparsed_lines.append(line)
    # re-join all lines
    unparsed_text = "\n".join(unparsed_lines)
    return unparsed_text
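# A worked example (hypothetical input) of the escaping scheme above:
#   parse_text("Intro\n```python\nprint(1+1)\n```")
#   -> 'Intro<pre><code class="language-python"> print&#40;1+1&#41;</code></pre>'
# unparse_text() maps the HTML entities back to the original characters and
# strips the <pre><code> tags when they appear on their own lines.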
def compress_image(image_path, max_size=(512, 512)):
    img = Image.open(image_path)
    width, height = img.size
    aspect_ratio = width / height
    if width > max_size[0] or height > max_size[1]:
        if width > height:
            new_width = max_size[0]
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = max_size[1]
            new_width = int(new_height * aspect_ratio)
        img = img.resize((new_width, new_height), Image.LANCZOS)
        img.save(image_path, optimize=True, quality=80)
class PredictClient:
    def __init__(self, api_url):
        self.api_url = api_url

    def check_health(self):
        health_check_url = f'{self.api_url}/health'
        try:
            response = requests.get(health_check_url)
            if response.status_code == 200:
                logger.info("Server is healthy and ready to process requests.")
                return True
            else:
                logger.error(f'Server health check failed with status code: {response.status_code}')
                return False
        except requests.exceptions.RequestException as e:
            logger.error(f'Health check request failed: {e}')
            return False

    def predict(self, image_path: str, text: str):
        payload = {"image_path": image_path, "text": text}
        headers = {'Content-Type': 'application/json'}
        response = requests.post(f"{self.api_url}/predict", json=payload, headers=headers)
        if response.status_code == 200:
            result = response.json()
            return result.get('Generated Text', '')
        else:
            raise Exception(f"Predict API request failed with status code {response.status_code}")
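# Wire format used by PredictClient: POST {api_url}/predict with JSON body
# {"image_path": ..., "text": ...}; the server replies with
# {"Generated Text": ...} (see ocr_vllm_server.py below).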
def main():
    args = parse_args()
    client = PredictClient(args.url)
    try:
        start_time = time.time()  # record the start time
        # compress the image before sending
        compress_image(args.image_path)
        generated_text = client.predict(args.image_path, parse_text(args.text))
        end_time = time.time()  # record the end time
        elapsed_time = end_time - start_time  # compute the elapsed time
        if generated_text:
            clean_text = unparse_text(generated_text)  # decode the generated text
            logger.info(f"Image Path: {args.image_path}")
            logger.info(f"Generated Text: {clean_text}")
            logger.info(f"Elapsed time: {elapsed_time} seconds")  # log the elapsed time
        else:
            logger.warning("Received empty generated text.")
    except requests.exceptions.RequestException as e:
        logger.error(f"Error while making request to predict service: {e}")
    except Exception as e:
        logger.error(f"Unexpected error occurred: {e}")


if __name__ == "__main__":
    main()
\ No newline at end of file
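For reference, a minimal sketch of driving this client in-process rather than through argparse; the URL and image path are placeholders taken from the argparse defaults above, and the prompt string is a made-up example:

from magic_pdf.dict2md.ocr_vllm_client import PredictClient, compress_image

client = PredictClient('http://0.0.0.0:6088')
if client.check_health():
    compress_image('/path/to/your/image.png')   # shrink the image in place first
    print(client.predict('/path/to/your/image.png', 'Describe this image'))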
magic_pdf/dict2md/ocr_vllm_server.py (new file, 0 → 100644)
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree

import configparser
import copy
import re
import gc
import torch
from argparse import ArgumentParser
from threading import Thread
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
import os
from loguru import logger
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional

logger.add("parse.log", rotation="10 MB", level="INFO",
           format="{time} {level} {message}", encoding='utf-8', enqueue=True)

app = FastAPI()

DEFAULT_CKPT_PATH = '/home/practice/model/Qwen2-VL-7B-Instruct'
REVISION = 'v1.0.4'
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
PUNCTUATION = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
def get_args():
    parser = ArgumentParser()
    parser.add_argument('-c', '--checkpoint_path', type=str, default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu_only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--flash_attn2', action='store_true', default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--share', action='store_true', default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser', action='store_true', default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--gpu_nums', type=int, default=1,
                        help='Number of GPUs to use for tensor parallelism.')
    parser.add_argument('--dcu_id', type=str, default=None,
                        help='Specify the GPU ID to load the model onto.')
    parser.add_argument('--config_path', default='/home/practice/magic_pdf-main/magic_pdf/config.ini')
    args = parser.parse_args()
    return args
def load_model_processor(args):
    if args.cpu_only:
        device = 'cpu'
    elif args.dcu_id is not None:  # guard: assigning None to os.environ raises TypeError
        os.environ['CUDA_VISIBLE_DEVICES'] = args.dcu_id
        print(f"Visible CUDA devices: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    llm = LLM(
        model=args.checkpoint_path,
        limit_mm_per_prompt={"image": 10, "video": 10},
        trust_remote_code=True,
        tensor_parallel_size=args.gpu_nums,  # adjust --gpu_nums to the actual hardware
        dtype='float16',  # or 'bfloat16'
    )
    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    return llm, processor
def parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line.strip() != ""]  # drop empty lines
    count = 0
    parsed_lines = []
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # opening a code block
                parsed_lines.append(f'<pre><code class="language-{items[-1]}">')
            else:
                # closing a code block
                parsed_lines.append("</code></pre>")
        else:
            if i > 0 and count % 2 == 1:
                # escape special characters inside a code block
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            # join lines with a space
            if parsed_lines:
                parsed_lines[-1] += " " + line
            else:
                parsed_lines.append(line)
    text = "".join(parsed_lines)
    return text
def unparse_text(parsed_text):
    in_code_block = False
    lines = parsed_text.split("\n")
    unparsed_lines = []
    for line in lines:
        if "<pre><code" in line:
            in_code_block = True
            # strip the opening tag
            line = line.split(">", 1)[1]
        elif "</code></pre>" in line:
            in_code_block = False
            # strip the closing tag
            line = line.rsplit("<", 1)[0]
        # reverse the HTML entities
        line = line.replace("&lt;", "<")
        line = line.replace("&gt;", ">")
        line = line.replace("&nbsp;", " ")
        line = line.replace("&ast;", "*")
        line = line.replace("&lowbar;", "_")
        line = line.replace("&#45;", "-")
        line = line.replace("&#46;", ".")
        line = line.replace("&#33;", "!")
        line = line.replace("&#40;", "(")
        line = line.replace("&#41;", ")")
        line = line.replace("&#36;", "$")
        # inside a code block, undo the backslash escaping of backticks
        if in_code_block:
            line = line.replace(r"\`", "`")
        unparsed_lines.append(line)
    # re-join all lines
    unparsed_text = "\n".join(unparsed_lines)
    return unparsed_text
def remove_image_special(text):
    text = text.replace('<ref>', '').replace('</ref>', '')
    return re.sub(r'<box>.*?(</box>|$)', '', text)
def is_video_file(filename):
    video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
    return any(filename.lower().endswith(ext) for ext in video_extensions)
def transform_messages(original_messages):
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_item = {'type': 'image', 'image': item['image']}
            elif 'text' in item:
                new_item = {'type': 'text', 'text': item['text']}
            elif 'video' in item:
                new_item = {'type': 'video', 'video': item['video']}
            else:
                continue
            new_content.append(new_item)
        new_message = {'role': message['role'], 'content': new_content}
        transformed_messages.append(new_message)
    return transformed_messages
def _gc():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def call_local_model(llm, processor, messages):
    messages = transform_messages(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    mm_data = {}
    if image_inputs is not None:
        mm_data["image"] = image_inputs
    if video_inputs is not None:
        mm_data["video"] = video_inputs
    llm_inputs = {
        "prompt": text,
        "multi_modal_data": mm_data,
    }
    sampling_params = SamplingParams(
        temperature=0.1,
        top_p=0.001,
        repetition_penalty=1.05,
        max_tokens=256,
        stop_token_ids=[],
    )
    outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
    generated_text = outputs[0].outputs[0].text
    yield parse_text(generated_text)
def create_predict_fn(llm, processor):
    def predict(_chatbot, task_history):
        chat_query = _chatbot[-1][0]
        query = task_history[-1][0]
        if len(chat_query) == 0:
            _chatbot.pop()
            task_history.pop()
            return _chatbot
        print('User: ' + parse_text(query))
        history_cp = copy.deepcopy(task_history)
        full_response = ''
        messages = []
        content = []
        for q, a in history_cp:
            if isinstance(q, (tuple, list)):
                if is_video_file(q[0]):
                    content.append({'video': f'file://{q[0]}'})
                else:
                    content.append({'image': f'file://{q[0]}'})
            else:
                content.append({'text': q})
                messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{'text': a}]})
                content = []
        messages.pop()
        for response in call_local_model(llm, processor, messages):
            _chatbot[-1] = (parse_text(chat_query), remove_image_special(parse_text(response)))
            yield _chatbot
            full_response = parse_text(response)
        task_history[-1] = (query, full_response)
        print('Qwen-VL-Chat: ' + unparse_text(full_response))
        yield _chatbot
    return predict
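# Note: create_predict_fn() is carried over from the Qwen-VL Gradio demo; the
# FastAPI routes below call call_local_model() directly and never invoke it.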
# load the model and processor at import time
args = get_args()
llm, processor = load_model_processor(args)
class Item(BaseModel):
    image_path: str
    text: str


@app.get("/health")
async def health_check():
    return {"status": "healthy"}
@app.post("/predict")
async def predict(item: Item):
    messages = [
        {
            'role': 'user',
            'content': [
                {'image': item.image_path},
                {'text': item.text}
            ]
        }
    ]
    generated_text = ''
    for response in call_local_model(llm, processor, messages):
        generated_text = unparse_text(response)
    _gc()
    return {"Generated Text": generated_text}
if __name__ == "__main__":
    import uvicorn
    args = get_args()
    config = configparser.ConfigParser()
    config.read(args.config_path)
    host = config.get('server', 'ocr_host')
    port = int(config.get('server', 'ocr_port'))
    # use the host/port read from config.ini (args has no host/port attributes)
    uvicorn.run(app, host=host, port=port)
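For reference, a minimal sketch of the config.ini contents the __main__ block above expects. The section and key names come from the config.get('server', ...) calls; the host and port values are placeholder assumptions:

import configparser

EXAMPLE_INI = """
[server]
ocr_host = 0.0.0.0
ocr_port = 6088
"""

config = configparser.ConfigParser()
config.read_string(EXAMPLE_INI)               # same parsing path as config.read(path)
print(config.get('server', 'ocr_host'))       # -> 0.0.0.0
print(int(config.get('server', 'ocr_port')))  # -> 6088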