Commit bb0a99c2 authored by Rayyyyy

Fix bugs: thread sampling_params through init_model and LLMInference, fix the vLLM generate call, and switch main() to the llm_inference server path

parent 073c3410
@@ -69,6 +69,7 @@ class LLMInference:
def __init__(self,
model,
tokenizer,
sampling_params,
device: str = 'cuda',
use_vllm: bool = False,
stream_chat: bool = False
@@ -77,6 +78,7 @@ class LLMInference:
self.device = device
self.model = model
self.tokenizer = tokenizer
self.sampling_params = sampling_params
self.use_vllm = use_vllm
self.stream_chat = stream_chat
@@ -120,7 +122,7 @@ class LLMInference:
## vllm
prompt_token_ids = [self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
outputs = self.model.generate(prompt_token_ids=prompt_token_ids, sampling_params=self.tokenzier)
outputs = self.model.generate(prompt_token_ids=prompt_token_ids, sampling_params=self.sampling_params)
output_text = []
for output in outputs:
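The corrected line passes the stored `SamplingParams` object instead of the tokenizer. A minimal, self-contained sketch of this vLLM chat path is below; the model path and sampling values are placeholders, not values taken from this repository.

```python
# Sketch of the vLLM generation path used above; the path and sampling values
# are illustrative assumptions.
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_path = "/path/to/model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = LLM(model=model_path, trust_remote_code=True, dtype="float16")
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)

messages = [{"role": "user", "content": "hello"}]
# apply_chat_template returns token ids when tokenize=True (the default)
prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
outputs = model.generate(prompt_token_ids=prompt_token_ids,
                         sampling_params=sampling_params)
output_text = [output.outputs[0].text for output in outputs]
print(output_text[0])
```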
@@ -134,9 +136,6 @@ class LLMInference:
else:
# transformers
output_text = ''
input_ids = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
outputs = self.model.generate(
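When `use_vllm` is false, the same messages go through the plain transformers branch instead. A sketch of that path, with an assumed `max_new_tokens` value (the repo's generation arguments are not shown in this hunk):

```python
# Sketch of the transformers fallback branch; generation arguments are
# illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True).half().cuda().eval()

messages = [{"role": "user", "content": "hello"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
with torch.no_grad():
    outputs = model.generate(input_ids, max_new_tokens=512)
# decode only the newly generated tokens, skipping the prompt
output_text = tokenizer.decode(
    outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(output_text)
```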
@@ -170,10 +169,10 @@ class LLMInference:
def init_model(model_path, use_vllm=False, tp_size=1):
## init models
# huggingface
logger.info("Starting initial model of Llama")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if use_vllm:
try:
# vllm
from vllm import LLM, SamplingParams
@@ -187,13 +186,11 @@ def init_model(model_path, use_vllm=False, tp_size=1):
enforce_eager=True,
dtype="float16",
tensor_parallel_size=tp_size)
return model, sampling_params
except Exception as e:
logger.error(f"vllm initial failed, {e}")
return model, tokenizer, sampling_params
else:
# huggingface
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
return model, tokenizer
return model, tokenizer, None
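`init_model` now returns a three-tuple so that the vLLM `SamplingParams` object can be threaded into `LLMInference`; the transformers branch returns `None` in that slot. A compact sketch of the contract, with assumed sampling values:

```python
# Hypothetical restatement of the three-value contract; the SamplingParams
# arguments are assumptions, not values read from this repository.
def init_model_sketch(model_path, use_vllm=False, tp_size=1):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if use_vllm:
        from vllm import LLM, SamplingParams
        sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)
        model = LLM(model=model_path,
                    trust_remote_code=True,
                    enforce_eager=True,
                    dtype="float16",
                    tensor_parallel_size=tp_size)
        return model, tokenizer, sampling_params
    # the transformers path has no SamplingParams, so the third slot is None
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True).half().cuda().eval()
    return model, tokenizer, None

# Callers unpack all three values:
# model, tokenizer, sampling_params = init_model_sketch("/path/to/model", use_vllm=True)
```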
def llm_inference(args):
@@ -208,21 +205,23 @@ def llm_inference(args):
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
inference = LLMInference(model,
model, tokenzier, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
llm_infer = LLMInference(model,
tokenzier,
sampling_params,
use_vllm=use_vllm,
stream_chat=args.stream_chat)
stream_chat=stream_chat)
async def inference(request):
start = time.time()
input_json = await request.json()
prompt = input_json['query']
history = input_json['history']
if args.stream_chat:
text = inference.stream_chat(prompt=prompt, history=history)
if stream_chat:
text = llm_infer.stream_chat(prompt=prompt, history=history)
else:
text = inference.chat(prompt=prompt, history=history)
text = llm_infer.chat(prompt=prompt, history=history)
end = time.time()
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, end - start))
return web.json_response({'text': text})
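The handler reads `query` and `history` from the request JSON and returns the generated text. The route registration is not part of this diff, so the path and port below are assumptions; the sketch reuses the `inference` coroutine defined above.

```python
# Hypothetical wiring for the aiohttp handler; '/inference' and port 8888
# are assumptions, and `inference` is the coroutine defined above.
from aiohttp import web

app = web.Application()
app.add_routes([web.post('/inference', inference)])
# web.run_app(app, host='0.0.0.0', port=8888)

# Client-side usage sketch:
#   import requests
#   resp = requests.post('http://localhost:8888/inference',
#                        json={'query': '写一首诗', 'history': []})
#   print(resp.json()['text'])
```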
@@ -243,13 +242,13 @@ def infer_test(args):
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
inference = LLMInference(model,
llm_infer = LLMInference(model,
tokenzier,
use_vllm=use_vllm,
stream_chat=stream_chat)
time_first = time.time()
output_text = inference.chat(args.query)
output_text = llm_infer.chat(args.query)
time_second = time.time()
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(
args.query, output_text, time_second - time_first))
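The same three-value unpacking applies wherever `init_model` is called. This hunk only renames the local variable, so the `sampling_params` argument in the sketch below is an assumption about how the test path would line up with the new signatures:

```python
# Assumed shape of infer_test's setup under the new signatures; not part of
# this commit, which only renames `inference` to `llm_infer` here.
model, tokenzier, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
llm_infer = LLMInference(model,
                         tokenzier,
                         sampling_params,
                         use_vllm=use_vllm,
                         stream_chat=stream_chat)
output_text = llm_infer.chat(args.query)
```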
@@ -274,12 +273,12 @@ def parse_args():
help='config directory')
parser.add_argument(
'--query',
default=['2000e防火墙恢复密码和忘记IP查询操作'],
default=['写一首诗'],
help='Question to ask.')
parser.add_argument(
'--DCU_ID',
type=str,
default='0,1',
default='6',
help='Set DCU card IDs, separated by commas, e.g. "0,1,2"')
args = parser.parse_args()
return args
@@ -288,8 +287,8 @@ def parse_args():
def main():
args = parse_args()
set_envs(args.DCU_ID)
# llm_inference(args)
infer_test(args)
llm_inference(args)
# infer_test(args)
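`main` sets the visible DCU cards from `--DCU_ID` before loading the model. `set_envs` itself is not shown in this diff, so the following is only a guess at its behavior; `HIP_VISIBLE_DEVICES` and `CUDA_VISIBLE_DEVICES` are the usual variables but are assumptions here.

```python
import os

def set_envs_sketch(dcu_ids: str) -> None:
    """Hypothetical stand-in for set_envs: restrict visible accelerators."""
    # DCU/ROCm stacks typically honor HIP_VISIBLE_DEVICES; CUDA builds use
    # CUDA_VISIBLE_DEVICES. Both assignments are assumptions.
    os.environ['HIP_VISIBLE_DEVICES'] = dcu_ids
    os.environ['CUDA_VISIBLE_DEVICES'] = dcu_ids

# set_envs_sketch('6')  # mirrors the new --DCU_ID default
```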
if __name__ == '__main__':
......