chenych / chat_demo · Commits

Commit 405b3897, authored Aug 06, 2024 by chenych

Add stream client

Parent: 00f38043

Showing 3 changed files with 29 additions and 100 deletions (+29 −100)
llm_service/client.py          +15 −8
llm_service/client_stream.py   +0 −74
llm_service/inferencer.py      +14 −18
llm_service/client.py

@@ -2,6 +2,7 @@ import json
 import argparse
 import requests
 import configparser
 from typing import Iterable, List
@@ -57,18 +58,24 @@ if __name__ == "__main__":
     api_url = f"http://localhost:8888/{func}"
     if stream_chat:
-        response = requests.get(api_url, headers=headers, data=json_str.encode("utf-8"), verify=False, stream=stream_chat)
+        headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive"}
+        response = requests.post(api_url, headers=headers, data=json_str.encode("utf-8"), verify=False, stream=True)
         num_printed_lines = 0
         for h in get_streaming_response(response):
-            clear_line(num_printed_lines)
+            # clear_line(num_printed_lines)
             num_printed_lines = 0
             for i, line in enumerate(h):
                 num_printed_lines += 1
-                print(f"Beam candidate {i}: {line!r}", flush=True)
+                print(f"{line!r}", flush=True)
     else:
-        response = requests.get(api_url, headers=headers, data=json_str.encode("utf-8"), verify=False, stream=stream_chat)
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(api_url, headers=headers, data=json_str.encode("utf-8"), verify=False)
         output = get_response(response)
         for i, line in enumerate(output):
             print(f"Beam candidate {i}: {line!r}", flush=True)
\ No newline at end of file
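For context, the streaming branch above relies on get_streaming_response, which is defined outside this hunk; judging from the deleted client_stream.py below, it reads NUL-delimited JSON chunks from the response body. A minimal sketch of how the new POST-based streaming path fits together, assuming a hypothetical func value of "chat" and a json_str payload shaped like the deleted client's pload (neither is shown in this hunk):

import json
from typing import Iterable, List

import requests


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    # Each chunk is a JSON object terminated by a NUL byte, e.g. {"text": [...]}\0
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            yield json.loads(chunk.decode("utf-8"))["text"]


func = "chat"  # hypothetical endpoint name; the real value comes from earlier in client.py
json_str = json.dumps({"query": "San Francisco is a", "stream": True})  # assumed payload shape
api_url = f"http://localhost:8888/{func}"
headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive"}
response = requests.post(api_url, headers=headers, data=json_str.encode("utf-8"), verify=False, stream=True)
for lines in get_streaming_response(response):
    for line in lines:
        print(f"{line!r}", flush=True)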
llm_service/client_stream.py
deleted 100644 → 0

"""Example Python client for vllm.entrypoints.api_server"""
import argparse
import json
from typing import Iterable, List

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(query: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "query": query,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": stream,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> List[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8888)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--query", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()

    query = args.query
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    print(f"Prompt: {query!r}\n", flush=True)
    response = post_http_request(query, api_url, n, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Beam candidate {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line!r}", flush=True)
\ No newline at end of file
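Before its removal, this script was a self-contained CLI; judging from its argparse options, it would have been invoked roughly like this (the host, port, and query values here are just the script's own defaults):

# Non-streaming request: prints the beam candidates once
python llm_service/client_stream.py --host localhost --port 8888 --n 4 --query "San Francisco is a"

# Streaming request: redraws the candidate lines as new chunks arrive
python llm_service/client_stream.py --stream --query "San Francisco is a"

The streaming behaviour now lives directly in llm_service/client.py, which is presumably why this standalone example was dropped.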
llm_service/inferencer.py

@@ -11,7 +11,6 @@ from aiohttp import web
 from transformers import AutoModelForCausalLM, AutoTokenizer

 COMMON = {
     "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
     "<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
@@ -259,12 +258,10 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params):
 def vllm_inference_stream(bind_port, model, tokenizer, sampling_params):
     '''Start a web server that accepts HTTP requests and generates responses by calling the local LLM inference service.'''
-    from typing import AsyncGenerator
-    from fastapi.responses import StreamingResponse
     async def inference(request):
         input_json = await request.json()
         prompt = input_json['query']
         # history = input_json['history']
@@ -272,28 +269,27 @@ def vllm_inference_stream(bind_port, model, tokenizer, sampling_params):
         logger.info("****************** use vllm ******************")
         ## generate template
         input_text = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True)
         logger.info(f"The input_text is {input_text}")
         assert model is not None
         request_id = str(uuid.uuid4().hex)
         results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)

         # Streaming case
-        async def stream_results() -> AsyncGenerator[bytes, None]:
-            # final_output = None
-            logger.info("****************** in stream_results *****************")
-            async for request_output in results_generator:
-                # final_output = request_output
-                text_outputs = [output.text for output in request_output.outputs]
-                prompt = request_output.prompt
-                ret = {"text": text_outputs}
-                print(ret)
-                yield (json.dumps(ret) + "\0").encode("utf-8")
-        return StreamingResponse(stream_results())
+        logger.info("****************** in stream chat ******************")
+        response = web.StreamResponse()
+        await response.prepare(request)
+        text_outputs = None
+        async for request_output in results_generator:
+            text_outputs = [output.text for output in request_output.outputs]
+            ret = {"text": text_outputs}
+            await response.write((json.dumps(ret) + "\0").encode("utf-8"))
+        response.write_eof()
+        logger.info("****************** in chat stream *****************")
+        return response

     app = web.Application()
-    app.add_routes([web.get('/vllm_inference_stream', inference)])
+    app.add_routes([web.post('/vllm_inference_stream', inference)])
     web.run_app(app, host='0.0.0.0', port=bind_port)
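The new handler streams with aiohttp's StreamResponse instead of FastAPI's StreamingResponse, writing NUL-terminated JSON chunks that the updated client parses. A minimal self-contained sketch of that pattern, in which fake_generator is a hypothetical stand-in for vLLM's results_generator (the real model.generate call is not reproduced here):

import asyncio
import json

from aiohttp import web


async def fake_generator():
    # Stand-in for model.generate(...): yields the partial text as it grows.
    partial = ""
    for token in ["San", " Francisco", " is", " a", " city"]:
        partial += token
        await asyncio.sleep(0.1)
        yield partial


async def inference(request):
    # Prepare a chunked HTTP response and push each partial result to the client.
    response = web.StreamResponse()
    await response.prepare(request)
    async for text in fake_generator():
        ret = {"text": [text]}
        # NUL-terminated JSON chunks, matching get_streaming_response in the client.
        await response.write((json.dumps(ret) + "\0").encode("utf-8"))
    await response.write_eof()
    return response


app = web.Application()
app.add_routes([web.post('/vllm_inference_stream', inference)])

if __name__ == "__main__":
    web.run_app(app, host='0.0.0.0', port=8888)

The POST route mirrors the client change above, where the stream branch switched from requests.get to requests.post with stream=True.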