Commit e958fb21 authored by Rayyyyy

substitution

parent 7375d90a
@@ -7,6 +7,7 @@ from aiohttp import web
 import torch
 from loguru import logger
+from fastllm_pytools import llm
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
@@ -48,60 +49,69 @@ class InferenceWrapper:
         self.stream_chat = stream_chat
         # huggingface
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        # self.model = AutoModelForCausalLM.from_pretrained(model_path,
-        #                                                   trust_remote_code=True,
-        #                                                   torch_dtype=torch.float16).cuda().eval()
-        model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
+        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda()
         self.model = model.eval()
         if self.use_vllm:
-            ## vllm
-            # from vllm import LLM, SamplingParams
-            #
-            # self.sampling_params = SamplingParams(temperature=1, top_p=0.95)
-            # self.llm = LLM(model=model_path,
-            #                trust_remote_code=True,
-            #                enforce_eager=True,
-            #                tensor_parallel_size=tensor_parallel_size)
-            ## fastllm
-            from fastllm_pytools import llm
             try:
+                ## vllm
+                # from vllm import LLM, SamplingParams
+                # self.sampling_params = SamplingParams(temperature=1, top_p=0.95)
+                # self.llm = LLM(model=model_path,
+                #                trust_remote_code=True,
+                #                enforce_eager=True,
+                #                tensor_parallel_size=tensor_parallel_size)
+                ## fastllm
                 if self.stream_chat:
                     # fastllm streaming initialization
                     self.model = llm.model(model_path)
                 else:
                     self.model = llm.from_hf(self.model, self.tokenizer, dtype="float16")
             except Exception as e:
                 logger.error(f"fastllm initial failed, {e}")

+    def substitution(self, output_text):
+        import re
+        matchObj = re.split('.*(<.*>).*', output_text, re.M|re.I)
+        if matchObj:
+            obj = matchObj[1]
+            replace_str = COMMON.get(obj)
+            if replace_str:
+                output_text = output_text.replace(obj, replace_str)
+                logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
+        return output_text
+
     def chat(self, prompt: str, history=[]):
         '''Single-turn Q&A.'''
-        import re
+        print("in chat")
         output_text = ''
         try:
             if self.use_vllm:
-                ## vllm
-                # output_text = []
-                # outputs = self.llm.generate(prompt, self.sampling_params)
-                # for output in outputs:
-                #     prompt = output.prompt
-                #     generated_text = output.outputs[0].text
-                #     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-                #     output_text.append(generated_text)
-                ## fastllm
                 output_text = self.model.response(prompt)
             else:
                 output_text, _ = self.model.chat(self.tokenizer,
                                                  prompt,
                                                  history,
                                                  do_sample=False)
-            matchObj = re.match('.*(<.*>).*', output_text)
-            if matchObj:
-                obj = matchObj.group(1)
-                replace_str = COMMON.get(obj)
-                output_text = output_text.replace(obj, replace_str)
-                logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
+            output_text = self.substitution(output_text)
+            print("output_text", output_text)
         except Exception as e:
             logger.error(f"chat inference failed, {e}")
         return output_text

     def chat_stream(self, prompt: str, history=[]):
         '''Streaming chat service.'''
         import re
@@ -109,30 +119,16 @@ class InferenceWrapper:
             from fastllm_pytools import llm
             # Fastllm
             for response in self.model.stream_response(prompt, history=[]):
-                matchObj = re.match('.*(<.*>).*', response)
-                if matchObj:
-                    obj = matchObj.group(1)
-                    replace_str = COMMON.get(obj)
-                    response = response.replace(obj, replace_str)
-                    logger.info(f"{obj} be replaced {replace_str}, after {response}")
+                response = self.substitution(response)
                 yield response
         else:
             # HuggingFace
             current_length = 0
-            for response, _, past_key_values in self.model.stream_chat(self.tokenizer, prompt, history=history,
+            for response, _, _ in self.model.stream_chat(self.tokenizer, prompt, history=history,
                                                          past_key_values=None,
                                                          return_past_key_values=True):
                 output_text = response[current_length:]
-                matchObj = re.match('.*(<.*>).*', output_text)
-                if matchObj:
-                    obj = matchObj.group(1)
-                    replace_str = COMMON.get(obj)
-                    output_text = output_text.replace(obj, replace_str)
-                    logger.info(f"{obj} be replaced {replace_str}, after {output_text}")
+                output_text = self.substitution(output_text)
                 yield output_text
                 current_length = len(response)
@@ -147,13 +143,13 @@ class LLMInference:
     ) -> None:
         self.device = device
         self.inference = InferenceWrapper(model_path=model_path,
                                           use_vllm=use_vllm,
                                           stream_chat=stream_chat,
                                           tensor_parallel_size=tensor_parallel_size)

     def generate_response(self, prompt, history=[]):
+        print("generate")
         output_text = ''
         error = ''
         time_tokenizer = time.time()
@@ -181,6 +177,7 @@ def llm_inference(args):
     bind_port = int(config['default']['bind_port'])
     model_path = config['llm']['local_llm_path']
     use_vllm = config.getboolean('llm', 'use_vllm')
+    print("inference")
    inference_wrapper = InferenceWrapper(model_path,
                                         use_vllm=use_vllm,
                                         tensor_parallel_size=1,
@@ -204,6 +201,27 @@ def llm_inference(args):
     web.run_app(app, host='0.0.0.0', port=bind_port)

+
+def infer_test(args):
+    config = configparser.ConfigParser()
+    config.read(args.config_path)
+    model_path = config['llm']['local_llm_path']
+    use_vllm = config.getboolean('llm', 'use_vllm')
+    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
+    inference_wrapper = InferenceWrapper(model_path,
+                                         use_vllm=use_vllm,
+                                         tensor_parallel_size=1,
+                                         stream_chat=args.stream_chat)
+    # prompt = "hello,please introduce yourself..."
+    prompt = '65N32-US主板清除CMOS配置的方法'
+    history = []
+    time_first = time.time()
+    output_text = inference_wrapper.chat(prompt)
+    time_second = time.time()
+    logger.debug('问题:{} 回答:{} \ntimecost {} '.format(
+        prompt, output_text, time_second - time_first))
+
+
 def set_envs(dcu_ids):
     try:
         os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
@@ -223,7 +241,7 @@ def parse_args():
         help='config目录')
     parser.add_argument(
         '--query',
-        default=['请问下产品的服务器保修或保修政策?'],
+        default=['2000e防火墙恢复密码和忘记IP查询操作'],
         help='提问的问题.')
     parser.add_argument(
         '--DCU_ID',
@@ -242,6 +260,7 @@ def main():
     args = parse_args()
     set_envs(args.DCU_ID)
     llm_inference(args)
+    # infer_test(args)


 if __name__ == '__main__':
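For reference, below is a minimal standalone sketch (not the committed code) of the placeholder-substitution step that the new substitution method performs. The COMMON entry and the name substitute_tags are hypothetical, used only for illustration. One caveat on the committed version: re.split's third positional parameter is maxsplit, not flags, so flags should be passed through the flags= keyword; this sketch uses re.search with an explicit no-match path instead.

# Minimal sketch of placeholder substitution, assuming COMMON maps tags to replacement text.
# COMMON below is a hypothetical example; the real mapping is defined elsewhere in the repo.
import re

from loguru import logger

COMMON = {'<kefu>': 'customer service'}  # hypothetical tag -> replacement mapping


def substitute_tags(output_text: str) -> str:
    """Replace the first '<...>' placeholder found in output_text using COMMON."""
    # Note: re.split(pattern, text, re.M | re.I) would silently pass the flags as maxsplit;
    # flags must go through the flags= keyword. re.search keeps the no-match case explicit.
    match = re.search(r'<.*?>', output_text, flags=re.M | re.I)
    if match:
        tag = match.group(0)
        replace_str = COMMON.get(tag)
        if replace_str:
            output_text = output_text.replace(tag, replace_str)
            logger.info(f"{tag} replaced with {replace_str}, after: {output_text}")
    return output_text


print(substitute_tags('Please contact <kefu> for further help.'))
# -> Please contact customer service for further help.

A helper of this shape covers both call sites in the commit, since chat and chat_stream only need "replace the first recognized tag when a mapping exists" applied to each generated chunk.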