chenych / chat_demo · Commits · a92273ba
"testing/python/vscode:/vscode.git/clone" did not exist on "899f7bd5324ef9466949bc872f5fca15d8f7048f"
Commit a92273ba, authored Aug 05, 2024 by chenych

Fix AutoTokenizer

parent 3edf4e00
Showing 2 changed files with 151 additions and 23 deletions (+151 -23)
llm_service/inferencer.py    +20 -20
llm_service/vllm_test.py     +131 -3
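The change renames the misspelled Autotokenzier class and the tokenzier variables to the real transformers.AutoTokenizer spelling throughout llm_service/inferencer.py. A minimal sketch of the corrected flow, assuming a placeholder model path and a model that ships a chat template (neither is taken from the repo config):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/local/model"  # placeholder, not from the repo config
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()

# Build chat-formatted input ids the same way LLMInference.generate_response does below.
messages = [{"role": "user", "content": "Hello"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))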
llm_service/inferencer.py  (view file @ a92273ba)
@@ -8,7 +8,7 @@ import asyncio
 from loguru import logger
 from aiohttp import web
 # from multiprocessing import Value
-from transformers import AutoModelForCausalLM, Autotokenzier
+from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -85,13 +85,13 @@ class LLMInference:
     def __init__(self,
                  model,
-                 tokenzier,
+                 tokenizer,
                  device: str = 'cuda',
                  ) -> None:
         self.device = device
         self.model = model
-        self.tokenzier = tokenzier
+        self.tokenizer = tokenizer

     def generate_response(self, prompt, history=[]):
         print("generate")
@@ -117,7 +117,7 @@ class LLMInference:
         logger.info("****************** in chat ******************")
         try:
             # transformers
-            input_ids = self.tokenzier.apply_chat_template(
+            input_ids = self.tokenizer.apply_chat_template(
                 messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
             outputs = self.model.generate(
                 input_ids,
@@ -125,7 +125,7 @@ class LLMInference:
             )
             response = outputs[0][input_ids.shape[-1]:]
-            generated_text = self.tokenzier.decode(response, skip_special_tokens=True)
+            generated_text = self.tokenizer.decode(response, skip_special_tokens=True)
             output_text = substitution(generated_text)
             logger.info(f"using transformers, output_text {output_text}")
@@ -142,7 +142,7 @@ class LLMInference:
             current_length = 0
             logger.info(f"stream_chat messages {messages}")
-            for response, _, _ in self.model.stream_chat(self.tokenzier, messages, history=history,
+            for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
                                                          max_length=1024,
                                                          past_key_values=None,
                                                          return_past_key_values=True):
@@ -158,20 +158,20 @@ def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
     ## init models
     logger.info("Starting initial model of LLM")
-    tokenzier = Autotokenzier.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     if use_vllm:
         from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
         sampling_params = SamplingParams(temperature=1,
                                          top_p=0.95,
                                          max_tokens=1024,
                                          early_stopping=False,
-                                         stop_token_ids=[tokenzier.eos_token_id]
+                                         stop_token_ids=[tokenizer.eos_token_id]
                                          )
         # basic vLLM configuration
         args = AsyncEngineArgs(model_path)
         args.worker_use_ray = False
         args.engine_use_ray = False
-        args.tokenzier = model_path
+        args.tokenizer = model_path
         args.tensor_parallel_size = tensor_parallel_size
         args.trust_remote_code = True
         args.enforce_eager = True
@@ -179,16 +179,16 @@ def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
         args.dtype = 'float16'
         # load the model
         engine = AsyncLLMEngine.from_engine_args(args)
-        return engine, tokenzier, sampling_params
+        return engine, tokenizer, sampling_params
     else:
         # huggingface
         model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
-        return model, tokenzier, None
+        return model, tokenizer, None


-def hf_inference(bind_port, model, tokenzier, stream_chat):
+def hf_inference(bind_port, model, tokenizer, stream_chat):
     '''Start the hf web server, receive HTTP requests, and generate responses by calling the local LLM inference service.'''
-    llm_infer = LLMInference(model, tokenzier)
+    llm_infer = LLMInference(model, tokenizer)

     async def inference(request):
         start = time.time()
@@ -213,7 +213,7 @@ def hf_inference(bind_port, model, tokenzier, stream_chat):
     web.run_app(app, host='0.0.0.0', port=bind_port)


-def vllm_inference(bind_port, model, tokenzier, sampling_params, stream_chat):
+def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
     '''Start the web server, receive HTTP requests, and generate responses by calling the local LLM inference service.'''
     import uuid
@@ -231,7 +231,7 @@ def vllm_inference(bind_port, model, tokenzier, sampling_params, stream_chat):
        logger.info("****************** use vllm ******************")
        ## generate template
-       input_text = tokenzier.apply_chat_template(
+       input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        logger.info(f"The input_text is {input_text}")
        assert model is not None
@@ -285,9 +285,9 @@ def infer_test(args):
     stream_chat = config.getboolean('llm', 'stream_chat')
     logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
-    model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
+    model, tokenizer = init_model(model_path, use_vllm, tensor_parallel_size)
     llm_infer = LLMInference(model,
-                             tokenzier,
+                             tokenizer,
                              use_vllm=use_vllm)
     time_first = time.time()
@@ -340,11 +340,11 @@ def main():
     stream_chat = config.getboolean('llm', 'stream_chat')
     logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
-    model, tokenzier, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
+    model, tokenizer, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
     if use_vllm:
-        vllm_inference(bind_port, model, tokenzier, sampling_params, stream_chat)
+        vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
     else:
-        hf_inference(bind_port, model, tokenzier, sampling_params, stream_chat)
+        hf_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
     # infer_test(args)
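Both infer_test and main above pull their settings from a configparser file. A hypothetical ../config.ini matching the keys read in these hunks (values are placeholders; the key that supplies bind_port is not shown in this diff and is omitted here):

[llm]
local_llm_path = /path/to/local/model
use_vllm = true
tensor_parallel_size = 1
stream_chat = false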
llm_service/vllm_test.py  (view file @ a92273ba)
import time
import os
import configparser
import argparse
# import torch
import asyncio
import uuid
from typing import AsyncGenerator
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from fastapi.responses import JSONResponse, Response, StreamingResponse
COMMON = {
    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
    "<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
    "<平台联系方式>": "1、访问官网,根据您所在地地址联系平台人员,网址地址:https://www.sugon.com/about/contact;\n2、点击人工客服进行咨询;\n3、请您拨打中科曙光服务热线400-810-0466联系人工进行咨询。",
    "<购买与维修的咨询方法>": "1、确定付费处理,可以微信搜索'sugon中科曙光服务'小程序,选择'在线报修'业务\n2、先了解价格,可以微信搜索'sugon中科曙光服务'小程序,选择'其他咨询'业务\n3、请您拨打中科曙光服务热线400-810-0466",
    "<服务器续保流程>": "1、微信搜索'sugon中科曙光服务'小程序,选择'延保与登记'业务\n2、点击人工客服进行登记\n3、请您拨打中科曙光服务热线400-810-0466根据语音提示选择维保与购买",
    "<XC内外网OS网盘链接>": "【腾讯文档】XC内外网OS网盘链接:https://docs.qq.com/sheet/DTWtXbU1BZHJvWkJm",
    "<W360-G30机器,安装Win7使用的镜像链接>": "W360-G30机器,安装Win7使用的镜像链接:https://pan.baidu.com/s/1SjHqCP6kJ9KzdJEBZDEynw;提取码:x6m4",
    "<麒麟系统搜狗输入法下载链接>": "软件下载链接(百度云盘):链接:https://pan.baidu.com/s/18Iluvs4BOAfFET0yFMBeLQ,提取码:bhkf",
    "<X660 G45 GPU服务器拆解视频网盘链接>": "链接: https://pan.baidu.com/s/1RkRGh4XY1T2oYftGnjLp4w;提取码: v2qi",
    "<DS800,SANTRICITY存储IBM版本模拟器网盘链接>": "链接:https://pan.baidu.com/s/1euG9HGbPfrVbThEB8BX76g;提取码:o2ya",
    "<E80-D312(X680-G55)风冷整机组装说明下载链接>": "链接:https://pan.baidu.com/s/17KDpm-Z9lp01WGp9sQaQ4w;提取码:0802",
    "<X680 G55 风冷相关资料下载链接>": "链接:https://pan.baidu.com/s/1KQ-hxUIbTWNkc0xzrEQLjg;提取码:0802",
    "<R620 G51刷写EEPROM下载>": "下载链接如下:http://10.2.68.104/tools/bytedance/eeprom/",
    "<X7450A0服务器售后培训文件网盘链接>": "网盘下载:https://pan.baidu.com/s/1tZJIf_IeQLOWsvuOawhslQ?pwd=kgf1;提取码:kgf1",
    "<福昕阅读器补丁链接>": "补丁链接: https://pan.baidu.com/s/1QJQ1kHRplhhFly-vxJquFQ,提取码: aupx1",
    "<W330-H35A_22DB4/W3335HA安装win7网盘链接>": "硬盘链接: https://pan.baidu.com/s/1fDdGPH15mXiw0J-fMmLt6Q提取码: k97i",
    "<X680 G55服务器售后培训资料网盘链接>": "云盘连接下载:链接:https://pan.baidu.com/s/1gaok13DvNddtkmk6Q-qLYg?pwd=xyhb提取码:xyhb",
    "<展厅管理员>": "北京-穆淑娟18001053012\n天津-马书跃15720934870\n昆山-关天琪15304169908\n成都-贾小芳18613216313\n重庆-李子艺17347743273\n安阳-郭永军15824623085\n桐乡-李梦瑶18086537055\n青岛-陶祉伊15318733259",
    "<线上预约展厅>": "北京、天津、昆山、成都、重庆、安阳、桐乡、青岛",
    "<马华>": "联系人:马华,电话:13761751980,邮箱:china@pinbang.com",
    "<梁静>": "联系人:梁静,电话:18917566297,邮箱:ing.liang@omaten.com",
    "<徐斌>": "联系人:徐斌,电话:13671166044,邮箱:244898943@qq.com",
    "<俞晓枫>": "联系人:俞晓枫,电话13750869272,邮箱:857233013@qq.com",
    "<刘广鹏>": "联系人:刘广鹏,电话13321992411,邮箱:liuguangpeng@pinbang.com",
    "<马英伟>": "联系人:马英伟,电话:13260021849,邮箱:13260021849@163.com",
    "<杨洋>": "联系人:杨洋,电话15801203938,邮箱bing523888@163.com",
    "<展会合规要求>": "1.展品内容:展品内容需符合公司合规要求,展示内容需经过法务合规审查。\n2.文字材料内容:文字材料内容需符合公司合规要求,展示内容需经过法务合规审查。\n3.展品标签:展品标签内容需符合公司合规要求。\n4.礼品内容:礼品内容需符合公司合规要求。\n5.视频内容:视频内容需符合公司合规要求,展示内容需经过法务合规审查。\n6.讲解词内容:讲解词内容需符合公司合规要求,展示内容需经过法务合规审查。\n7.现场发放材料:现场发放的材料内容需符合公司合规要求。\n8.展示内容:整体展示内容需要经过法务合规审查。",
    "<展会质量>": "1.了解展会的组织者背景、往届展会的评价以及提供的服务支持,确保展会的专业性和高效性。\n.了解展会的规模、参观人数、行业影响力等因素,以判断展会是否能够提供足够的曝光度和商机。\n3.关注同行业其他竞争对手是否参展,以及他们的展位布置、展示内容等信息,以便制定自己的参展策略。\n4.展会的日期是否与公司的其他重要活动冲突,以及举办地点是否便于客户和合作伙伴的参观。\n5.销售部门会询问展会方提供的宣传渠道和推广服务,以及如何利用这些资源来提升公司及产品的知名度。\n6.记录展会期间的重要领导参观、商机线索、合作洽谈、公司拜访预约等信息,跟进后续商业机会。",
    "<摊位费规则>": "根据展位面积大小,支付相应费用。\n展位照明费:支付展位内的照明服务费。\n展位保安费:支付展位内的保安服务费。\n展位网络使用费:支付展位内网络使用的费用。\n展位电源使用费:支付展位内电源使用的费用。",
    "<展会主题要求>": "展会主题的确定需要符合公司产品和服务业务范围,以确保能够吸引目标客户群体。因此,确定展会主题时,需要考虑以下因素:\n专业性:展会的主题应确保专业性,符合行业特点和目标客户的需求。\n目标客户群体:展会的主题定位应考虑目标客户群体,确保能够吸引他们的兴趣。\n业务重点:展会的主题应突出公司的业务重点和优势,以便更好地推广公司的核心产品或服务。\n行业影响力:展会的主题定位需要考虑行业的最新发展趋势,以凸显公司的行业地位和影响力。\n往届展会经验:可以参考往届展会的主题定位,总结经验教训,以确定本届展会的主题。\n市场部意见:在确定展会主题时,应听取市场部的意见,确保主题符合公司的整体市场战略。\n领导意见:还需要考虑公司领导的意见,以确保展会主题符合公司的战略发展方向。",
    "<办理展商证注意事项>": "人员范围:除公司领导和同事需要办理展商证外,展会运营工作人员也需要办理。\n提前准备:展商证的办理需要提前进行,以确保摄影师、摄像师等工作人员可以提前入场进行布置。\n办理流程:需要熟悉展商证的办理流程,准备好相关材料,如身份证件等。\n数量需求:需要评估所需的展商证数量,避免数量不足或过多的情况。\n有效期限:展商证的有效期限需要注意,避免在展期内过期。\n存放安全:办理完的展商证需要妥善保管,避免丢失或被他人使用。\n使用规范:使用展商证时需要遵守展会相关规定,不得转让给他人使用。\n回收处理:展会结束后,需要及时回收展商证,避免泄露相关信息。",
    "<项目单价要求>": "请注意:无论是否年框供应商,项目单价都不得超过采购部制定的“2024常见活动项目标准单价”,此报价仅可内部使用,严禁外传",
    "<年框供应商细节表格>": "在线表格https://kdocs.cn/l/camwZE63frNw",
    "<年框供应商流程>": "1.需求方发出项目需求(大型项目需比稿)\n2.外协根据项目需求报价,提供需求方“预算单”(按照基准单价报价,如有发现不按单价情况,解除合同不再使用)\n3.需求方确认预算价格,并提交OA市场活动申请\n4.外协现场执行\n5.需求方现场验收,并签署验收单(物料、设备、人员等实际清单)\n6.外协出具结算单(金额与验收单一致,加盖公章)、结案报告、年框合同,作为报销凭证\n7.外协请需求方项目负责人填写“满意度调研表”(如无,会影响年度评价)\n8.需求方项目经理提交报销",
    "<市场活动结案报告内容>": "1.项目简介(时间、地点、参与人数等);2.最终会议安排;3.活动各环节现场图片;4.费用相关证明材料(如执行人员、物料照片);5.活动成效汇总;6.活动原始照片/视频网络链接",
    "<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
    "<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
    "": "",
}
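inferencer.py calls substitution(generated_text) on the decoded model output, and COMMON maps angle-bracket placeholders to fixed answers. The helper itself is not part of this commit; a hypothetical sketch of such a placeholder replacement, assuming COMMON is the lookup table it uses:

def substitution(text: str) -> str:
    # Hypothetical helper, not in this commit: swap any "<placeholder>" the
    # model emits for the canned answer stored in COMMON.
    for placeholder, replacement in COMMON.items():
        if placeholder:  # skip the empty-key entry at the end of COMMON
            text = text.replace(placeholder, replacement)
    return text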
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
    ## init models
    # huggingface
@@ -54,15 +114,20 @@ def llm_inference(args):
     print(text)
     assert model is not None
     request_id = str(uuid.uuid4().hex)
-    results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
+    ## vllm-0.5.0
+    # results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
+    ## vllm-0.3.3
+    results_generator = model.generate(prompt=text, sampling_params=sampling_params, request_id=request_id)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
+            prompt = request_output.prompt
             text_outputs = [output.text for output in request_output.outputs]
             ret = {"text": text_outputs}
-            yield (json.dumps(ret) + "\0").encode("utf-8")
+            print(ret)
+            # yield (json.dumps(ret) + "\0").encode("utf-8")
+            yield web.json_response({'text': text})

     if stream_chat:
         return StreamingResponse(stream_results())
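The comments in this hunk pin the call to vllm-0.3.3, whose AsyncLLMEngine.generate takes the text as prompt=..., while the vllm-0.5.0 line kept commented out passes inputs=... instead. If both versions had to be supported, one possible guard is sketched below; the helper name and the 0.4.0 cutoff are illustrative guesses, since only the 0.3.3 and 0.5.0 behaviour is documented by this diff:

import vllm
from packaging import version

def build_generator(model, text, sampling_params, request_id):
    # Illustrative only: choose the keyword the installed vLLM release expects.
    # The exact version at which the keyword changed is assumed, not verified.
    if version.parse(vllm.__version__) >= version.parse("0.4.0"):
        return model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
    return model.generate(prompt=text, sampling_params=sampling_params, request_id=request_id)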
@@ -86,3 +151,66 @@ def llm_inference(args):
     app = web.Application()
     app.add_routes([web.post('/inference', inference)])
     web.run_app(app, host='0.0.0.0', port=bind_port)
+
+
+def infer_test(args):
+    config = configparser.ConfigParser()
+    config.read(args.config_path)
+    model_path = config['llm']['local_llm_path']
+    use_vllm = config.getboolean('llm', 'use_vllm')
+    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
+    stream_chat = config.getboolean('llm', 'stream_chat')
+    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
+    model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
+    llm_infer = LLMInference(model, tokenzier, use_vllm=use_vllm)
+    time_first = time.time()
+    output_text = llm_infer.chat(args.query)
+    time_second = time.time()
+    logger.debug('Question: {} Answer: {}\ntimecost {} '.format(args.query, output_text, time_second - time_first))
+
+
+def set_envs(dcu_ids):
+    try:
+        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
+        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
+    except Exception as e:
+        logger.error(f"{e}, but got {dcu_ids}")
+        raise ValueError(f"{e}")
+
+
+def parse_args():
+    '''Arguments.'''
+    parser = argparse.ArgumentParser(description='Feature store for processing directories.')
+    parser.add_argument('--config_path', default='../config.ini', help='config directory')
+    parser.add_argument('--query', default='写一首诗', help='Question to ask.')
+    parser.add_argument('--DCU_ID', type=str, default='4', help='DCU card ids, comma-separated, e.g. "0,1,2"')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    set_envs(args.DCU_ID)
+    llm_inference(args)
+    # infer_test(args)
+
+
+if __name__ == '__main__':
+    main()
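For reference, a sketch of how the test script might be launched and exercised once the /inference route registered above is serving. The listening port comes from the config file, and the JSON field name in the request body is an assumption, since the inference handler's parsing is not shown in this diff:

# Launch the server (paths and DCU ids are examples):
#   python llm_service/vllm_test.py --config_path ../config.ini --DCU_ID "0,1"

# Then post to the /inference route; the port and the "query" key are assumed.
import requests

resp = requests.post("http://127.0.0.1:8000/inference", json={"query": "写一首诗"})
print(resp.text)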