chenych / chat_demo / Commits / e08c9060

Commit e08c9060, authored Aug 05, 2024 by chenych

Modify codes

Parent: f0863458
Showing 2 changed files with 21 additions and 37 deletions.
llm_service/inferencer.py   +8 −5
llm_service/vllm_test.py    +13 −32
llm_service/inferencer.py
@@ -228,7 +228,6 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
        # history = input_json['history']
        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use vllm ******************")
        ## generate template
        input_text = tokenizer.apply_chat_template(
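The template step above wraps a single-turn prompt in the model's chat format before it is handed to vLLM. A minimal sketch of that call, assuming a Hugging Face tokenizer with a chat template (the model path and prompt below are placeholders, not taken from the commit):

    from transformers import AutoTokenizer

    # Placeholder model path; the service reads the real path from its config.
    tokenizer = AutoTokenizer.from_pretrained("/path/to/local/model", trust_remote_code=True)

    # Same single-turn message structure as in the hunk above.
    messages = [{"role": "user", "content": "Hello"}]
    # Render the chat template to a plain string; generation happens later in vLLM.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    print(input_text)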
@@ -248,12 +247,15 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
            print(ret)
            # yield (json.dumps(ret) + "\0").encode("utf-8")
            # yield web.json_response({'text': text_outputs})
        return final_output
        assert final_output is not None
        return [output.text for output in final_output.outputs]
    if stream_chat:
        logger.info("****************** in chat stream *****************")
        # return StreamingResponse(stream_results())
        output_text = await stream_results()
        text = await stream_results()
        output_text = substitution(text)
        logger.debug('问题:{} 回答:{}\ntimecost {} '.format(prompt, output_text, time.time() - start))
        return web.json_response({'text': output_text})
    # Non-streaming case
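One caveat on the `await stream_results()` lines above: with the yield statements commented out, stream_results behaves as an ordinary coroutine and the await is valid, but if it still yielded chunks (as the version in vllm_test.py below does) it would be an async generator and awaiting it would raise a TypeError. A small helper that drains an async generator is one way to cover that case; a minimal sketch, where collect_final is illustrative and not part of the commit:

    from typing import Any, AsyncGenerator, Optional

    async def collect_final(gen: AsyncGenerator[Any, None]) -> Optional[Any]:
        # Drain the async generator and keep only the last yielded chunk.
        last = None
        async for chunk in gen:
            last = chunk
        return last

    # Usage inside the handler (illustrative):
    #   text = await collect_final(stream_results())
    #   output_text = substitution(text)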
@@ -269,9 +271,10 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
    assert final_output is not None
    text = [output.text for output in final_output.outputs]
    end = time.time()
    output_text = substitution(text)
    logger.debug('问题:{} 回答:{}\ntimecost {} '.format(prompt, text, end - start))
    logger.debug('问题:{} 回答:{}\ntimecost {} '.format(prompt, output_text, time.time() - start))
    return web.json_response({'text': output_text})

app = web.Application()
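The handler ends by returning an aiohttp JSON response, and the hunk closes with the web.Application that hosts it. A minimal sketch of how such a handler is typically wired up, with the route path, handler name, and request field as assumptions rather than details taken from the commit:

    from aiohttp import web

    async def handle_chat(request: web.Request) -> web.Response:
        # 'prompt' is an assumed field name for the incoming JSON body.
        payload = await request.json()
        prompt = payload.get("prompt", "")
        output_text = prompt  # placeholder; the real handler runs vLLM generation here
        return web.json_response({"text": output_text})

    app = web.Application()
    app.add_routes([web.post("/chat", handle_chat)])
    # web.run_app(app, host="0.0.0.0", port=8000)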
llm_service/vllm_test.py
@@ -13,8 +13,6 @@ from aiohttp import web
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from fastapi.responses import JSONResponse, Response, StreamingResponse

COMMON = {
    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
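For reference, the async engine imported here is normally built from AsyncEngineArgs; a minimal initialization sketch, assuming a local model path and default sampling settings (all values are placeholders, and argument names can differ between vLLM releases, as the version comments in the next hunk suggest):

    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

    engine_args = AsyncEngineArgs(
        model="/path/to/local/model",  # placeholder path
        tensor_parallel_size=1,
    )
    model = AsyncLLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)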
@@ -113,24 +111,30 @@ def llm_inference(args):
        messages, tokenize=False, add_generation_prompt=True
    )
    print(text)
    assert model is not None
    request_id = str(uuid.uuid4().hex)
    ## vllm-0.5.0
    # results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
    ## vllm-0.3.3
    results_generator = model.generate(prompt=text, sampling_params=sampling_params, request_id=request_id)
    results_generator = model.generate(text, sampling_params=sampling_params, request_id=request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        final_output = None
        async for request_output in results_generator:
            final_output = request_output
            text_outputs = [output.text for output in request_output.outputs]
            ret = {"text": text_outputs}
            print(ret)
            # yield (json.dumps(ret) + "\0").encode("utf-8")
            yield web.json_response({'text': text})
            # yield web.json_response({'text': text_outputs})
        assert final_output is not None
        return [output.text for output in final_output.outputs]

    if stream_chat:
        return StreamingResponse(stream_results())
        logger.info("****************** in chat stream *****************")
        # return StreamingResponse(stream_results())
        text = await stream_results()
        output_text = substitution(text)
        logger.debug('问题:{} 回答:{}\ntimecost {} '.format(prompt, output_text, time.time() - start))
        return web.json_response({'text': output_text})
    # Non-streaming case
    final_output = None
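A note on the streaming branch above: fastapi's StreamingResponse expects the generator to yield bytes or str chunks, so the commented-out json.dumps(...).encode("utf-8") pattern streams cleanly, whereas yielding an aiohttp web.json_response object from inside the generator will not serialize as a byte stream. A minimal sketch of the byte-chunk form, assuming the same results_generator as above (the helper name is illustrative):

    import json
    from typing import AsyncGenerator

    async def stream_results_bytes(results_generator) -> AsyncGenerator[bytes, None]:
        # Each chunk is a NUL-terminated JSON document, mirroring the commented-out line above.
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            yield (json.dumps({"text": text_outputs}, ensure_ascii=False) + "\0").encode("utf-8")

    # Usage (illustrative): return StreamingResponse(stream_results_bytes(results_generator))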
@@ -153,28 +157,6 @@ def llm_inference(args):
    web.run_app(app, host='0.0.0.0', port=bind_port)

def infer_test(args):
    config = configparser.ConfigParser()
    config.read(args.config_path)
    model_path = config['llm']['local_llm_path']
    use_vllm = config.getboolean('llm', 'use_vllm')
    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
    stream_chat = config.getboolean('llm', 'stream_chat')
    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
    model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
    llm_infer = LLMInference(model, tokenzier, use_vllm=use_vllm)
    time_first = time.time()
    output_text = llm_infer.chat(args.query)
    time_second = time.time()
    logger.debug('问题:{} 回答:{}\ntimecost {} '.format(args.query, output_text, time_second - time_first))

def set_envs(dcu_ids):
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
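infer_test above pulls its settings from an INI file through configparser. A minimal sketch of a matching config, written and read back in Python; the section and key names come from the code, the values are placeholders:

    import configparser

    config = configparser.ConfigParser()
    # Placeholder values; the real ones live in the deployment's config file.
    config["llm"] = {
        "local_llm_path": "/path/to/local/model",
        "use_vllm": "true",
        "tensor_parallel_size": "1",
        "stream_chat": "false",
    }
    with open("config.ini", "w") as f:
        config.write(f)

    config.read("config.ini")
    print(config.getboolean("llm", "use_vllm"))          # True
    print(config.getint("llm", "tensor_parallel_size"))  # 1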
@@ -209,7 +191,6 @@ def main():
    args = parse_args()
    set_envs(args.DCU_ID)
    llm_inference(args)
    # infer_test(args)

if __name__ == '__main__':
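main() expects config_path, query, and DCU_ID on the parsed arguments. A minimal parse_args sketch consistent with those attribute names; the defaults and help text are assumptions, not taken from the commit:

    import argparse

    def parse_args():
        parser = argparse.ArgumentParser(description="vLLM inference service test")
        parser.add_argument("--config_path", type=str, default="config.ini",
                            help="INI file read by configparser")
        parser.add_argument("--query", type=str, default="Hello",
                            help="prompt used by infer_test")
        parser.add_argument("--DCU_ID", type=str, default="0",
                            help="device ids exported as CUDA_VISIBLE_DEVICES")
        return parser.parse_args()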