Commit a92273ba authored by chenych's avatar chenych
Browse files

Fix AutoTokenizer

parent 3edf4e00
......@@ -8,7 +8,7 @@ import asyncio
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, Autotokenzier
from transformers import AutoModelForCausalLM, AutoTokenizer
......@@ -85,13 +85,13 @@ class LLMInference:
def __init__(self,
model,
tokenzier,
tokenizer,
device: str = 'cuda',
) -> None:
self.device = device
self.model = model
self.tokenzier = tokenzier
self.tokenizer = tokenizer
def generate_response(self, prompt, history=[]):
print("generate")
......@@ -117,7 +117,7 @@ class LLMInference:
logger.info("****************** in chat ******************")
try:
# transformers
input_ids = self.tokenzier.apply_chat_template(
input_ids = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
outputs = self.model.generate(
input_ids,
......@@ -125,7 +125,7 @@ class LLMInference:
)
response = outputs[0][input_ids.shape[-1]:]
generated_text = self.tokenzier.decode(response, skip_special_tokens=True)
generated_text = self.tokenizer.decode(response, skip_special_tokens=True)
output_text = substitution(generated_text)
logger.info(f"using transformers, output_text {output_text}")
......@@ -142,7 +142,7 @@ class LLMInference:
current_length = 0
logger.info(f"stream_chat messages {messages}")
for response, _, _ in self.model.stream_chat(self.tokenzier, messages, history=history,
for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
max_length=1024,
past_key_values=None,
return_past_key_values=True):
......@@ -158,20 +158,20 @@ def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
## init models
logger.info("Starting initial model of LLM")
tokenzier = Autotokenzier.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if use_vllm:
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
sampling_params = SamplingParams(temperature=1,
top_p=0.95,
max_tokens=1024,
early_stopping=False,
stop_token_ids=[tokenzier.eos_token_id]
stop_token_ids=[tokenizer.eos_token_id]
)
# vLLM基础配置
args = AsyncEngineArgs(model_path)
args.worker_use_ray = False
args.engine_use_ray = False
args.tokenzier = model_path
args.tokenizer = model_path
args.tensor_parallel_size = tensor_parallel_size
args.trust_remote_code = True
args.enforce_eager = True
......@@ -179,16 +179,16 @@ def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
args.dtype = 'float16'
# 加载模型
engine = AsyncLLMEngine.from_engine_args(args)
return engine, tokenzier, sampling_params
return engine, tokenizer, sampling_params
else:
# huggingface
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
return model, tokenzier, None
return model, tokenizer, None
def hf_inference(bind_port, model, tokenzier, stream_chat):
def hf_inference(bind_port, model, tokenizer, stream_chat):
'''启动 hf Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
llm_infer = LLMInference(model, tokenzier)
llm_infer = LLMInference(model, tokenizer)
async def inference(request):
start = time.time()
......@@ -213,7 +213,7 @@ def hf_inference(bind_port, model, tokenzier, stream_chat):
web.run_app(app, host='0.0.0.0', port=bind_port)
def vllm_inference(bind_port, model, tokenzier, sampling_params, stream_chat):
def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
'''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
import uuid
......@@ -231,7 +231,7 @@ def vllm_inference(bind_port, model, tokenzier, sampling_params, stream_chat):
logger.info("****************** use vllm ******************")
## generate template
input_text = tokenzier.apply_chat_template(
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
logger.info(f"The input_text is {input_text}")
assert model is not None
......@@ -285,9 +285,9 @@ def infer_test(args):
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
model, tokenizer = init_model(model_path, use_vllm, tensor_parallel_size)
llm_infer = LLMInference(model,
tokenzier,
tokenizer,
use_vllm=use_vllm)
time_first = time.time()
......@@ -340,11 +340,11 @@ def main():
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenzier, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
model, tokenizer, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
if use_vllm:
vllm_inference(bind_port, model, tokenzier, sampling_params, stream_chat)
vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
else:
hf_inference(bind_port, model, tokenzier, sampling_params, stream_chat)
hf_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
# infer_test(args)
......
import time
import os
import configparser
import argparse
# import torch
import asyncio
import uuid
from typing import AsyncGenerator
from loguru import logger
from aiohttp import web
# from multiprocessing import Value
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from fastapi.responses import JSONResponse, Response, StreamingResponse
COMMON = {
"<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
"<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
"<平台联系方式>": "1、访问官网,根据您所在地地址联系平台人员,网址地址:https://www.sugon.com/about/contact;\n2、点击人工客服进行咨询;\n3、请您拨打中科曙光服务热线400-810-0466联系人工进行咨询。",
"<购买与维修的咨询方法>": "1、确定付费处理,可以微信搜索'sugon中科曙光服务'小程序,选择'在线报修'业务\n2、先了解价格,可以微信搜索'sugon中科曙光服务'小程序,选择'其他咨询'业务\n3、请您拨打中科曙光服务热线400-810-0466",
"<服务器续保流程>": "1、微信搜索'sugon中科曙光服务'小程序,选择'延保与登记'业务\n2、点击人工客服进行登记\n3、请您拨打中科曙光服务热线400-810-0466根据语音提示选择维保与购买",
"<XC内外网OS网盘链接>": "【腾讯文档】XC内外网OS网盘链接:https://docs.qq.com/sheet/DTWtXbU1BZHJvWkJm",
"<W360-G30机器,安装Win7使用的镜像链接>": "W360-G30机器,安装Win7使用的镜像链接:https://pan.baidu.com/s/1SjHqCP6kJ9KzdJEBZDEynw;提取码:x6m4",
"<麒麟系统搜狗输入法下载链接>": "软件下载链接(百度云盘):链接:https://pan.baidu.com/s/18Iluvs4BOAfFET0yFMBeLQ,提取码:bhkf",
"<X660 G45 GPU服务器拆解视频网盘链接>": "链接: https://pan.baidu.com/s/1RkRGh4XY1T2oYftGnjLp4w;提取码: v2qi",
"<DS800,SANTRICITY存储IBM版本模拟器网盘链接>": "链接:https://pan.baidu.com/s/1euG9HGbPfrVbThEB8BX76g;提取码:o2ya",
"<E80-D312(X680-G55)风冷整机组装说明下载链接>": "链接:https://pan.baidu.com/s/17KDpm-Z9lp01WGp9sQaQ4w;提取码:0802",
"<X680 G55 风冷相关资料下载链接>": "链接:https://pan.baidu.com/s/1KQ-hxUIbTWNkc0xzrEQLjg;提取码:0802",
"<R620 G51刷写EEPROM下载>": "下载链接如下:http://10.2.68.104/tools/bytedance/eeprom/",
"<X7450A0服务器售后培训文件网盘链接>": "网盘下载:https://pan.baidu.com/s/1tZJIf_IeQLOWsvuOawhslQ?pwd=kgf1;提取码:kgf1",
"<福昕阅读器补丁链接>": "补丁链接: https://pan.baidu.com/s/1QJQ1kHRplhhFly-vxJquFQ,提取码: aupx1",
"<W330-H35A_22DB4/W3335HA安装win7网盘链接>": "硬盘链接: https://pan.baidu.com/s/1fDdGPH15mXiw0J-fMmLt6Q提取码: k97i",
"<X680 G55服务器售后培训资料网盘链接>": "云盘连接下载:链接:https://pan.baidu.com/s/1gaok13DvNddtkmk6Q-qLYg?pwd=xyhb提取码:xyhb",
"<展厅管理员>": "北京-穆淑娟18001053012\n天津-马书跃15720934870\n昆山-关天琪15304169908\n成都-贾小芳18613216313\n重庆-李子艺17347743273\n安阳-郭永军15824623085\n桐乡-李梦瑶18086537055\n青岛-陶祉伊15318733259",
"<线上预约展厅>": "北京、天津、昆山、成都、重庆、安阳、桐乡、青岛",
"<马华>": "联系人:马华,电话:13761751980,邮箱:china@pinbang.com",
"<梁静>": "联系人:梁静,电话:18917566297,邮箱:ing.liang@omaten.com",
"<徐斌>": "联系人:徐斌,电话:13671166044,邮箱:244898943@qq.com",
"<俞晓枫>": "联系人:俞晓枫,电话13750869272,邮箱:857233013@qq.com",
"<刘广鹏>": "联系人:刘广鹏,电话13321992411,邮箱:liuguangpeng@pinbang.com",
"<马英伟>": "联系人:马英伟,电话:13260021849,邮箱:13260021849@163.com",
"<杨洋>": "联系人:杨洋,电话15801203938,邮箱bing523888@163.com",
"<展会合规要求>": "1.展品内容:展品内容需符合公司合规要求,展示内容需经过法务合规审查。\n2.文字材料内容:文字材料内容需符合公司合规要求,展示内容需经过法务合规审查。\n3.展品标签:展品标签内容需符合公司合规要求。\n4.礼品内容:礼品内容需符合公司合规要求。\n5.视频内容:视频内容需符合公司合规要求,展示内容需经过法务合规审查。\n6.讲解词内容:讲解词内容需符合公司合规要求,展示内容需经过法务合规审查。\n7.现场发放材料:现场发放的材料内容需符合公司合规要求。\n8.展示内容:整体展示内容需要经过法务合规审查。",
"<展会质量>": "1.了解展会的组织者背景、往届展会的评价以及提供的服务支持,确保展会的专业性和高效性。\n.了解展会的规模、参观人数、行业影响力等因素,以判断展会是否能够提供足够的曝光度和商机。\n3.关注同行业其他竞争对手是否参展,以及他们的展位布置、展示内容等信息,以便制定自己的参展策略。\n4.展会的日期是否与公司的其他重要活动冲突,以及举办地点是否便于客户和合作伙伴的参观。\n5.销售部门会询问展会方提供的宣传渠道和推广服务,以及如何利用这些资源来提升公司及产品的知名度。\n6.记录展会期间的重要领导参观、商机线索、合作洽谈、公司拜访预约等信息,跟进后续商业机会。",
"<摊位费规则>": "根据展位面积大小,支付相应费用。\n展位照明费:支付展位内的照明服务费。\n展位保安费:支付展位内的保安服务费。\n展位网络使用费:支付展位内网络使用的费用。\n展位电源使用费:支付展位内电源使用的费用。",
"<展会主题要求>": "展会主题的确定需要符合公司产品和服务业务范围,以确保能够吸引目标客户群体。因此,确定展会主题时,需要考虑以下因素:\n专业性:展会的主题应确保专业性,符合行业特点和目标客户的需求。\n目标客户群体:展会的主题定位应考虑目标客户群体,确保能够吸引他们的兴趣。\n业务重点:展会的主题应突出公司的业务重点和优势,以便更好地推广公司的核心产品或服务。\n行业影响力:展会的主题定位需要考虑行业的最新发展趋势,以凸显公司的行业地位和影响力。\n往届展会经验:可以参考往届展会的主题定位,总结经验教训,以确定本届展会的主题。\n市场部意见:在确定展会主题时,应听取市场部的意见,确保主题符合公司的整体市场战略。\n领导意见:还需要考虑公司领导的意见,以确保展会主题符合公司的战略发展方向。",
"<办理展商证注意事项>": "人员范围:除公司领导和同事需要办理展商证外,展会运营工作人员也需要办理。\n提前准备:展商证的办理需要提前进行,以确保摄影师、摄像师等工作人员可以提前入场进行布置。\n办理流程:需要熟悉展商证的办理流程,准备好相关材料,如身份证件等。\n数量需求:需要评估所需的展商证数量,避免数量不足或过多的情况。\n有效期限:展商证的有效期限需要注意,避免在展期内过期。\n存放安全:办理完的展商证需要妥善保管,避免丢失或被他人使用。\n使用规范:使用展商证时需要遵守展会相关规定,不得转让给他人使用。\n回收处理:展会结束后,需要及时回收展商证,避免泄露相关信息。",
"<项目单价要求>": "请注意:无论是否年框供应商,项目单价都不得超过采购部制定的“2024常见活动项目标准单价”,此报价仅可内部使用,严禁外传",
"<年框供应商细节表格>": "在线表格https://kdocs.cn/l/camwZE63frNw",
"<年框供应商流程>": "1.需求方发出项目需求(大型项目需比稿)\n2.外协根据项目需求报价,提供需求方“预算单”(按照基准单价报价,如有发现不按单价情况,解除合同不再使用)\n3.需求方确认预算价格,并提交OA市场活动申请\n4.外协现场执行\n5.需求方现场验收,并签署验收单(物料、设备、人员等实际清单)\n6.外协出具结算单(金额与验收单一致,加盖公章)、结案报告、年框合同,作为报销凭证\n7.外协请需求方项目负责人填写“满意度调研表”(如无,会影响年度评价)\n8.需求方项目经理提交报销",
"<市场活动结案报告内容>": "1.项目简介(时间、地点、参与人数等);2.最终会议安排;3.活动各环节现场图片;4.费用相关证明材料(如执行人员、物料照片);5.活动成效汇总;6.活动原始照片/视频网络链接",
"<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
"<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
"":"",
}
def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
## init models
# huggingface
......@@ -54,15 +114,20 @@ def llm_inference(args):
print(text)
assert model is not None
request_id = str(uuid.uuid4().hex)
results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
## vllm-0.5.0
# results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
## vllm-0.3.3
results_generator = model.generate(prompt=text, sampling_params=sampling_params, request_id=request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [output.text for output in request_output.outputs]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
print(ret)
# yield (json.dumps(ret) + "\0").encode("utf-8")
yield web.json_response({'text': text})
if stream_chat:
return StreamingResponse(stream_results())
......@@ -86,3 +151,66 @@ def llm_inference(args):
app = web.Application()
app.add_routes([web.post('/inference', inference)])
web.run_app(app, host='0.0.0.0', port=bind_port)
def infer_test(args):
config = configparser.ConfigParser()
config.read(args.config_path)
model_path = config['llm']['local_llm_path']
use_vllm = config.getboolean('llm', 'use_vllm')
tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
llm_infer = LLMInference(model,
tokenzier,
use_vllm=use_vllm)
time_first = time.time()
output_text = llm_infer.chat(args.query)
time_second = time.time()
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(
args.query, output_text, time_second - time_first))
def set_envs(dcu_ids):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
except Exception as e:
logger.error(f"{e}, but got {dcu_ids}")
raise ValueError(f"{e}")
def parse_args():
'''参数'''
parser = argparse.ArgumentParser(
description='Feature store for processing directories.')
parser.add_argument(
'--config_path',
default='../config.ini',
help='config目录')
parser.add_argument(
'--query',
default='写一首诗',
help='提问的问题.')
parser.add_argument(
'--DCU_ID',
type=str,
default='4',
help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
args = parser.parse_args()
return args
def main():
args = parse_args()
set_envs(args.DCU_ID)
llm_inference(args)
# infer_test(args)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment