chenych / chat_demo · Commit e958fb21
Authored Jul 16, 2024 by Rayyyyy
Parent: 7375d90a

Commit message: substitution

Showing 1 changed file with 62 additions and 43 deletions.

llm_service/inferencer.py
@@ -7,6 +7,7 @@ from aiohttp import web
 import torch
 from loguru import logger
+from fastllm_pytools import llm
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
@@ -48,60 +49,69 @@ class InferenceWrapper:
         self.stream_chat = stream_chat
         # huggingface
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         # self.model = AutoModelForCausalLM.from_pretrained(model_path,
         #                                                   trust_remote_code=True,
         #                                                   torch_dtype=torch.float16).cuda().eval()
-        model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()
         self.model = model.eval()
         if self.use_vllm:
-            ## vllm
-            # from vllm import LLM, SamplingParams
-            #
-            # self.sampling_params = SamplingParams(temperature=1, top_p=0.95)
-            # self.llm = LLM(model=model_path,
-            #                trust_remote_code=True,
-            #                enforce_eager=True,
-            #                tensor_parallel_size=tensor_parallel_size)
-            ## fastllm
             from fastllm_pytools import llm
+            try:
+                ## vllm
+                # from vllm import LLM, SamplingParams
+                # self.sampling_params = SamplingParams(temperature=1, top_p=0.95)
+                # self.llm = LLM(model=model_path,
+                #                trust_remote_code=True,
+                #                enforce_eager=True,
+                #                tensor_parallel_size=tensor_parallel_size)
+                ## fastllm
                 if self.stream_chat:
                     # fastllm streaming initialization
                     self.model = llm.model(model_path)
                 else:
                     self.model = llm.from_hf(self.model, self.tokenizer, dtype="float16")
+            except Exception as e:
+                logger.error(f"fastllm initialization failed, {e}")
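Editor's aside: a minimal usage sketch of the two construction paths above. The checkpoint path, import path, and flag values are hypothetical (only the keyword arguments come from this file), and note that despite its name, use_vllm=True currently selects the fastllm branch, since the vLLM code is commented out.

# Editor's sketch (hypothetical paths and flags; keyword arguments taken from this file).
from llm_service.inferencer import InferenceWrapper  # assumed import path

MODEL_PATH = "/models/chatglm3-6b"  # hypothetical checkpoint location

# Plain HuggingFace path: self.model stays the AutoModelForCausalLM in eval mode.
hf_wrapper = InferenceWrapper(MODEL_PATH, use_vllm=False,
                              tensor_parallel_size=1, stream_chat=False)

# "use_vllm" path: actually initializes fastllm; if the conversion inside the new
# try/except fails, the error is only logged and self.model silently remains the
# HuggingFace model, so inference still works, just without fastllm acceleration.
accelerated_wrapper = InferenceWrapper(MODEL_PATH, use_vllm=True,
                                       tensor_parallel_size=1, stream_chat=True)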

+    def substitution(self, output_text):
+        import re
+        # Note: re.split's third positional argument is maxsplit, so the flags are
+        # passed by keyword; when the pattern does not match, the split result has a
+        # single element, hence the length check before indexing.
+        matchObj = re.split('.*(<.*>).*', output_text, flags=re.M | re.I)
+        if len(matchObj) > 1:
+            obj = matchObj[1]
+            replace_str = COMMON.get(obj)
+            if replace_str:
+                output_text = output_text.replace(obj, replace_str)
+                logger.info(f"{obj} replaced by {replace_str}, after: {output_text}")
+        return output_text
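The COMMON table that substitution() consults is defined elsewhere in the service and is not part of this diff. Below is a self-contained editor's sketch of the intended behaviour, with an invented placeholder and replacement value; re.search with a non-greedy pattern is used here because the original greedy '<.*>' would capture everything between the first '<' and the last '>'.

import re

# Hypothetical placeholder table; the real COMMON dict is configured elsewhere.
COMMON = {"<SUPPORT_PHONE>": "400-000-0000"}

def substitute(output_text: str) -> str:
    """Swap a <placeholder> token in the model output for its configured value."""
    match = re.search(r"<.*?>", output_text, re.M | re.I)
    if match:
        obj = match.group(0)
        replace_str = COMMON.get(obj)
        if replace_str:
            output_text = output_text.replace(obj, replace_str)
    return output_text

print(substitute("Please call <SUPPORT_PHONE> for warranty service."))
# Prints: Please call 400-000-0000 for warranty service.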

     def chat(self, prompt: str, history=[]):
         '''Single-turn Q&A'''
         import re
         print("in chat")
         output_text = ''
         try:
             if self.use_vllm:
                 ## vllm
                 # output_text = []
                 # outputs = self.llm.generate(prompt, self.sampling_params)
                 # for output in outputs:
                 #     prompt = output.prompt
                 #     generated_text = output.outputs[0].text
                 #     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
                 #     output_text.append(generated_text)
                 ## fastllm
                 output_text = self.model.response(prompt)
             else:
                 output_text, _ = self.model.chat(self.tokenizer, prompt, history, do_sample=False)
-                matchObj = re.match('.*(<.*>).*', output_text)
-                if matchObj:
-                    obj = matchObj.group(1)
-                    replace_str = COMMON.get(obj)
+                output_text = self.substitution(output_text)
                 print("output_text", output_text)
-                    output_text = output_text.replace(obj, replace_str)
-                    logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
         except Exception as e:
             logger.error(f"chat inference failed, {e}")
         return output_text

     def chat_stream(self, prompt: str, history=[]):
         '''Streaming service'''
         import re
@@ -109,30 +119,16 @@ class InferenceWrapper:
-            from fastllm_pytools import llm
             # Fastllm
             for response in self.model.stream_response(prompt, history=[]):
-                matchObj = re.match('.*(<.*>).*', response)
-                if matchObj:
-                    obj = matchObj.group(1)
-                    replace_str = COMMON.get(obj)
-                    response = response.replace(obj, replace_str)
-                    logger.info(f"{obj} be replaced {replace_str}, after {response}")
+                response = self.substitution(response)
                 yield response
         else:
             # HuggingFace
             current_length = 0
-            for response, _, past_key_values in self.model.stream_chat(self.tokenizer, prompt, history=history,
+            for response, _, _ in self.model.stream_chat(self.tokenizer, prompt, history=history,
                                                          past_key_values=None,
                                                          return_past_key_values=True):
                 output_text = response[current_length:]
-                matchObj = re.match('.*(<.*>).*', output_text)
-                if matchObj:
-                    obj = matchObj.group(1)
-                    replace_str = COMMON.get(obj)
-                    output_text = output_text.replace(obj, replace_str)
-                    logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
+                output_text = self.substitution(output_text)
                 yield output_text
                 current_length = len(response)
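For reference, an editor's sketch of how a caller might drain this generator. The wrapper object is assumed to be an already-constructed InferenceWrapper; the HuggingFace branch yields deltas sliced with current_length, and the fastllm branch yields whatever stream_response produces, which this sketch also treats as deltas.

# Editor's sketch: stream a reply chunk by chunk and echo deltas as they arrive.
def print_streamed_reply(wrapper, prompt: str) -> str:
    reply = ""
    for chunk in wrapper.chat_stream(prompt, history=[]):
        print(chunk, end="", flush=True)  # each chunk is treated as an incremental delta
        reply += chunk
    print()
    return reply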
@@ -147,13 +143,13 @@ class LLMInference:
     ) -> None:
         self.device = device
         self.inference = InferenceWrapper(model_path=model_path,
                                           use_vllm=use_vllm,
                                           stream_chat=stream_chat,
                                           tensor_parallel_size=tensor_parallel_size)

     def generate_response(self, prompt, history=[]):
         print("generate")
         output_text = ''
         error = ''
         time_tokenizer = time.time()
@@ -181,6 +177,7 @@ def llm_inference(args):
     bind_port = int(config['default']['bind_port'])
     model_path = config['llm']['local_llm_path']
     use_vllm = config.getboolean('llm', 'use_vllm')
+    print("inference")
     inference_wrapper = InferenceWrapper(model_path,
                                          use_vllm=use_vllm,
                                          tensor_parallel_size=1,
@@ -204,6 +201,27 @@ def llm_inference(args):
     web.run_app(app, host='0.0.0.0', port=bind_port)

+def infer_test(args):
+    config = configparser.ConfigParser()
+    config.read(args.config_path)
+    model_path = config['llm']['local_llm_path']
+    use_vllm = config.getboolean('llm', 'use_vllm')
+    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
+    inference_wrapper = InferenceWrapper(model_path,
+                                         use_vllm=use_vllm,
+                                         tensor_parallel_size=1,
+                                         stream_chat=args.stream_chat)
+    # prompt = "hello,please introduce yourself..."
+    prompt = '65N32-US主板清除CMOS配置的方法'  # i.e. "How to clear the CMOS settings on a 65N32-US mainboard"
+    history = []
+    time_first = time.time()
+    output_text = inference_wrapper.chat(prompt)
+    time_second = time.time()
+    logger.debug('Question: {} Answer: {}\n timecost {} '.format(prompt, output_text, time_second - time_first))
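The new infer_test() helper reads the same INI config as llm_inference(). An editor's sketch of driving it directly follows; the section values are placeholders, and only the key names (llm.local_llm_path, llm.use_vllm, llm.tensor_parallel_size) and the args attributes (config_path, stream_chat) come from this diff.

# Editor's sketch: write a throwaway config and run the smoke test once.
import argparse
import configparser

from inferencer import infer_test  # assumes this is run from inside llm_service/

config = configparser.ConfigParser()
config["llm"] = {
    "local_llm_path": "/models/chatglm3-6b",  # placeholder checkpoint path
    "use_vllm": "false",
    "tensor_parallel_size": "1",
}
with open("config_test.ini", "w") as f:
    config.write(f)

args = argparse.Namespace(config_path="config_test.ini", stream_chat=False)
infer_test(args)  # loads the model, runs one chat() call, logs the time cost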

 def set_envs(dcu_ids):
     try:
         os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids

@@ -223,7 +241,7 @@ def parse_args():
                         help='config directory')
     parser.add_argument('--query',
-                        default=['请问下产品的服务器保修或保修政策?'],  # "What is the server warranty policy for this product?"
+                        default=['2000e防火墙恢复密码和忘记IP查询操作'],  # "How to reset the password and look up a forgotten IP on a 2000e firewall"
                         help='The question to ask.')
     parser.add_argument('--DCU_ID',
@@ -242,6 +260,7 @@ def main():
     args = parse_args()
     set_envs(args.DCU_ID)
     llm_inference(args)
+    # infer_test(args)

 if __name__ == '__main__':