chenych / chat_demo · Commit 020c2e2f
Authored Aug 02, 2024 by Rayyyyy
Add codes
Parent: 8f65b603

Showing 2 changed files with 42 additions and 15 deletions:
  llm_service/client.py      +21  -0
  llm_service/inferencer.py  +21  -15
llm_service/client.py  (new file, mode 0 → 100644)
import json
import argparse
import requests

parse = argparse.ArgumentParser()
parse.add_argument('--query', default='请写一首诗')
args = parse.parse_args()
print(args.query)

headers = {"Content-Type": "application/json"}
data = {"query": args.query, "history": []}
json_str = json.dumps(data)

response = requests.post("http://localhost:8888/inference",
                         headers=headers,
                         data=json_str.encode("utf-8"),
                         verify=False)
str_response = response.content.decode("utf-8")
print(json.loads(str_response))
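The new client posts a JSON body of the form {"query": ..., "history": []} to the /inference route and prints the decoded JSON reply. For context, below is a minimal sketch of a matching aiohttp endpoint; the route and port are taken from the client URL, while the handler name and the echo reply are assumptions for illustration only, not code from this commit (the real handler lives in llm_service/inferencer.py).

# Minimal sketch of a server endpoint the client above could talk to.
# handle_inference and the echo reply are placeholders, not part of this commit.
from aiohttp import web

async def handle_inference(request: web.Request) -> web.Response:
    input_json = await request.json()        # {"query": "...", "history": [...]}
    prompt = input_json["query"]
    history = input_json["history"]
    text = f"echo: {prompt} (history length: {len(history)})"  # placeholder for the model call
    return web.json_response({"text": text})

app = web.Application()
app.add_routes([web.post("/inference", handle_inference)])

if __name__ == "__main__":
    web.run_app(app, port=8888)

With such a server listening on localhost:8888, the client is run as python llm_service/client.py --query '请写一首诗' ("please write a poem"). Note that verify=False has no effect over plain HTTP; it would only disable certificate checks if the URL were switched to HTTPS.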
llm_service/inferencer.py
@@ -2,11 +2,12 @@ import time
 import os
 import configparser
 import argparse
-import torch
+# import torch
+import asyncio
 from loguru import logger
 from aiohttp import web
-from multiprocessing import Value
+# from multiprocessing import Value
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -72,7 +73,6 @@ class LLMInference:
                  sampling_params,
                  device: str = 'cuda',
                  use_vllm: bool = False,
-                 ) -> None:
+                 stream_chat: bool = False) -> None:
         self.device = device
@@ -80,7 +80,6 @@ class LLMInference:
         self.tokenizer = tokenizer
         self.sampling_params = sampling_params
         self.use_vllm = use_vllm
+        self.stream_chat = stream_chat

     def generate_response(self, prompt, history=[]):
         print("generate")
@@ -120,8 +119,9 @@ class LLMInference:
         try:
             if self.use_vllm:
                 ## vllm
+                logger.info("****************** use vllm ******************")
                 prompt_token_ids = [self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
+                logger.info(f"before generate {messages}")
                 outputs = self.model.generate(prompt_token_ids=prompt_token_ids, sampling_params=self.sampling_params)
                 output_text = []
@@ -157,12 +157,18 @@ class LLMInference:
     def chat_stream(self, prompt: str, history=[]):
         '''流式服务'''
         # HuggingFace
+        logger.info("****************** in chat stream *****************")
         current_length = 0
-        for response, _, _ in self.model.stream_chat(self.tokenizer, prompt, history=history,
+        messages = [{"role": "user", "content": prompt}]
+        logger.info(f"stream_chat messages {messages}")
+        for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
+                                                     max_length=1024,
                                                      past_key_values=None,
                                                      return_past_key_values=True):
             output_text = response[current_length:]
             output_text = self.substitution(output_text)
+            logger.info(f"using transformers chat_stream, Prompt: {prompt!r}, Generated text: {output_text!r}")
             yield output_text
             current_length = len(response)
@@ -213,14 +219,15 @@ def llm_inference(args):
     llm_infer = LLMInference(model,
                              tokenzier,
                              sampling_params,
-                             use_vllm=use_vllm)
+                             use_vllm=use_vllm,
+                             stream_chat=stream_chat)
     prompt = input_json['query']
     history = input_json['history']
+    logger.info(f"prompt {prompt}")
     if stream_chat:
-        text = llm_infer.stream_chat(prompt=prompt, history=history)
+        text = await asyncio.to_thread(llm_infer.chat_stream, prompt=prompt, history=history)
     else:
-        text = llm_infer.chat(prompt=prompt, history=history)
+        text = await asyncio.to_thread(llm_infer.chat, prompt=prompt, history=history)
     end = time.time()
     logger.debug('问题:{} 回答:{}\n timecost {} '.format(prompt, text, end - start))
     return web.json_response({'text': text})
@@ -243,8 +250,7 @@ def infer_test(args):
     model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
     llm_infer = LLMInference(model,
                              tokenzier,
-                             use_vllm=use_vllm)
+                             use_vllm=use_vllm,
+                             stream_chat=stream_chat)
     time_first = time.time()
     output_text = llm_infer.chat(args.query)
@@ -272,7 +278,7 @@ def parse_args():
                         help='config目录')
     parser.add_argument(
         '--query',
-        default=['写一首诗'],
+        default='写一首诗',
         help='提问的问题.')
     parser.add_argument(
         '--DCU_ID',
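The inferencer changes add a stream_chat flag to LLMInference and move the blocking chat / chat_stream calls off the event loop with asyncio.to_thread, so the aiohttp handler stays responsive while the model generates. Below is a minimal, self-contained sketch of that offloading pattern; blocking_chat is a stand-in for the repository's model call, not code from this commit.

import asyncio
import time

def blocking_chat(prompt: str, history: list) -> str:
    # Stand-in for LLMInference.chat: a slow, synchronous model call.
    time.sleep(1)
    return f"answer to: {prompt}"

async def handle(prompt: str) -> str:
    # asyncio.to_thread (Python 3.9+) runs the blocking callable in a worker
    # thread and awaits its result, keeping the event loop free for other requests.
    return await asyncio.to_thread(blocking_chat, prompt, [])

if __name__ == "__main__":
    print(asyncio.run(handle("写一首诗")))  # "write a poem"

One caveat of the pattern: awaiting to_thread on a generator function such as chat_stream yields the generator object itself rather than the streamed pieces, so streaming responses need additional handling beyond this sketch.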