Commit e958fb21 authored by Rayyyyy

substitution

parent 7375d90a
@@ -7,6 +7,7 @@ from aiohttp import web
import torch
from loguru import logger
from fastllm_pytools import llm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
@@ -48,60 +49,69 @@ class InferenceWrapper:
self.stream_chat = stream_chat
# huggingface
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# self.model = AutoModelForCausalLM.from_pretrained(model_path,
# trust_remote_code=True,
# torch_dtype=torch.float16).cuda().eval()
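# load the base model weights in float16 on the GPU; eval() below disables dropout for inference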
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()
self.model = model.eval()
if self.use_vllm:
try:
## vllm
# from vllm import LLM, SamplingParams
#
# self.sampling_params = SamplingParams(temperature=1, top_p=0.95)
# self.llm = LLM(model=model_path,
# trust_remote_code=True,
# enforce_eager=True,
# tensor_parallel_size=tensor_parallel_size)
## fastllm
from fastllm_pytools import llm
try:
if self.stream_chat:
# fastllm streaming initialization
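# llm.model() loads weights in fastllm's own format, so model_path is assumed here to point at a fastllm-compatible export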
self.model = llm.model(model_path)
else:
self.model = llm.from_hf(self.model, self.tokenizer, dtype="float16")
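# llm.from_hf converts the already-loaded HuggingFace model and tokenizer into a fastllm model in float16 for faster non-streaming inference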
except Exception as e:
logger.error(f"fastllm initial failed, {e}")
def substitution(self, output_text):
'''Replace a placeholder token such as <xxx> in the output with its mapped value from COMMON.'''
import re
matchObj = re.match(r'.*(<.*>).*', output_text, re.M | re.I)
if matchObj:
obj = matchObj.group(1)
replace_str = COMMON.get(obj)
if replace_str:
output_text = output_text.replace(obj, replace_str)
logger.info(f"{obj} replaced with {replace_str}, result: {output_text}")
return output_text
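# intended behaviour (sketch, assuming COMMON maps placeholder tokens to display strings,
# e.g. COMMON = {'<mail>': 'support@example.com'}):
#   substitution('please contact <mail>') -> 'please contact support@example.com'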
def chat(self, prompt: str, history=[]):
'''Single-turn question answering.'''
import re
print("in chat")
output_text = ''
try:
if self.use_vllm:
## vllm
# output_text = []
# outputs = self.llm.generate(prompt, self.sampling_params)
# for output in outputs:
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# output_text.append(generated_text)
## fastllm
output_text = self.model.response(prompt)
else:
output_text, _ = self.model.chat(self.tokenizer,
prompt,
history,
do_sample=False)
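# do_sample=False makes the HuggingFace chat() call decode greedily, so the answer is deterministic for a given prompt and history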
matchObj = re.match('.*(<.*>).*', output_text)
if matchObj:
obj = matchObj.group(1)
replace_str = COMMON.get(obj)
output_text = self.substitution(output_text)
print("output_text", output_text)
output_text = output_text.replace(obj, replace_str)
logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
except Exception as e:
logger.error(f"chat inference failed, {e}")
return output_text
def chat_stream(self, prompt: str, history=[]):
'''Streaming chat service.'''
import re
@@ -109,30 +119,16 @@ class InferenceWrapper:
from fastllm_pytools import llm
# Fastllm
for response in self.model.stream_response(prompt, history=[]):
matchObj = re.match('.*(<.*>).*', response)
if matchObj:
obj = matchObj.group(1)
replace_str = COMMON.get(obj)
response = response.replace(obj, replace_str)
logger.info(f"{obj} be replaced {replace_str}, after {response}")
response = self.substitution(response)
yield response
else:
# HuggingFace
current_length = 0
for response, _, past_key_values in self.model.stream_chat(self.tokenizer, prompt, history=history,
for response, _, _ in self.model.stream_chat(self.tokenizer, prompt, history=history,
past_key_values=None,
return_past_key_values=True):
output_text = response[current_length:]
matchObj = re.match('.*(<.*>).*', output_text)
if matchObj:
obj = matchObj.group(1)
replace_str = COMMON.get(obj)
output_text = output_text.replace(obj, replace_str)
logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
output_text = self.substitution(output_text)
yield output_text
current_length = len(response)
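# stream_chat returns the full response generated so far on every iteration; slicing with current_length yields only the newly generated suffix each time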
@@ -147,13 +143,13 @@ class LLMInference:
) -> None:
self.device = device
self.inference = InferenceWrapper(model_path=model_path,
use_vllm=use_vllm,
stream_chat=stream_chat,
tensor_parallel_size=tensor_parallel_size)
def generate_response(self, prompt, history=[]):
print("generate")
output_text = ''
error = ''
time_tokenizer = time.time()
@@ -181,6 +177,7 @@ def llm_inference(args):
bind_port = int(config['default']['bind_port'])
model_path = config['llm']['local_llm_path']
use_vllm = config.getboolean('llm', 'use_vllm')
print("inference")
inference_wrapper = InferenceWrapper(model_path,
use_vllm=use_vllm,
tensor_parallel_size=1,
@@ -204,6 +201,27 @@ def llm_inference(args):
web.run_app(app, host='0.0.0.0', port=bind_port)
def infer_test(args):
config = configparser.ConfigParser()
config.read(args.config_path)
model_path = config['llm']['local_llm_path']
use_vllm = config.getboolean('llm', 'use_vllm')
tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
inference_wrapper = InferenceWrapper(model_path,
use_vllm=use_vllm,
tensor_parallel_size=tensor_parallel_size,
stream_chat=args.stream_chat)
# prompt = "hello,please introduce yourself..."
prompt = '65N32-US主板清除CMOS配置的方法'  # "How to clear the CMOS configuration on a 65N32-US motherboard"
history = []
time_first = time.time()
output_text = inference_wrapper.chat(prompt)
time_second = time.time()
logger.debug('Question: {} Answer: {} \ntime cost {} '.format(
prompt, output_text, time_second - time_first))
def set_envs(dcu_ids):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
@@ -223,7 +241,7 @@ def parse_args():
help='config directory')
parser.add_argument(
'--query',
default=['请问下产品的服务器保修或保修政策?'],
default=['2000e防火墙恢复密码和忘记IP查询操作'],
help='Question to ask.')
parser.add_argument(
'--DCU_ID',
@@ -242,6 +260,7 @@ def main():
args = parse_args()
set_envs(args.DCU_ID)
llm_inference(args)
# infer_test(args)
if __name__ == '__main__':