"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "a9d74b669e71d5db989f8766f9535620f20238c5"
Commit bb0a99c2 authored by Rayyyyy

Fix bugs

parent 073c3410
@@ -69,6 +69,7 @@ class LLMInference:
     def __init__(self,
                  model,
                  tokenizer,
+                 sampling_params,
                  device: str = 'cuda',
                  use_vllm: bool = False,
                  stream_chat: bool = False
@@ -77,6 +78,7 @@ class LLMInference:
         self.device = device
         self.model = model
         self.tokenizer = tokenizer
+        self.sampling_params = sampling_params
         self.use_vllm = use_vllm
         self.stream_chat = stream_chat
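With this change the constructor takes the SamplingParams object alongside the model and tokenizer. A minimal wiring sketch, mirroring the call site in llm_inference() further down; the model path and flags are placeholders, not values from this repository's config:

    # Sketch only: "path/to/model" and the flags below are placeholders.
    model, tokenizer, sampling_params = init_model("path/to/model", use_vllm=True, tp_size=1)
    llm_infer = LLMInference(model,
                             tokenizer,
                             sampling_params,   # None when use_vllm=False
                             use_vllm=True,
                             stream_chat=False)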
@@ -120,7 +122,7 @@ class LLMInference:
             ## vllm
             prompt_token_ids = [self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
-            outputs = self.model.generate(prompt_token_ids=prompt_token_ids, sampling_params=self.tokenzier)
+            outputs = self.model.generate(prompt_token_ids=prompt_token_ids, sampling_params=self.sampling_params)
             output_text = []
             for output in outputs:
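The old call passed self.tokenzier where a SamplingParams object is expected; the fix forwards the stored self.sampling_params instead. For reference, the loop that follows usually unpacks vLLM RequestOutput objects roughly like this (a sketch; the actual loop body is outside this hunk):

    # Assumes `outputs` is the list of vllm RequestOutput objects returned above.
    output_text = []
    for output in outputs:
        output_text.append(output.outputs[0].text)  # first candidate completion per prompt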
@@ -134,9 +136,6 @@ class LLMInference:
         else:
             # transformers
-            output_text = ''
             input_ids = self.tokenizer.apply_chat_template(
                 messages, add_generation_prompt=True, return_tensors="pt").to('cuda')
             outputs = self.model.generate(
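The non-vLLM branch no longer pre-initialises output_text; the string is built from the generate() result alone. The decode step itself sits outside this hunk, so the following is only a hedged sketch of what it usually looks like, assuming `outputs` holds full sequences (prompt plus completion) as returned by transformers' generate():

    # Sketch, assuming `tokenizer`, `input_ids`, and `outputs` as in the branch above.
    generated_ids = outputs[0][input_ids.shape[-1]:]          # drop the prompt tokens
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)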
@@ -170,30 +169,28 @@ class LLMInference:
 def init_model(model_path, use_vllm=False, tp_size=1):
     ## init models
-    # huggingface
-    logger.info("Starting initial model of Llama")
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     if use_vllm:
-        try:
-            # vllm
-            from vllm import LLM, SamplingParams
-            sampling_params = SamplingParams(temperature=1,
-                                             top_p=0.95,
-                                             max_tokens=1024,
-                                             stop_token_ids=[tokenizer.eos_token_id])
-            model = LLM(model=model_path,
-                        trust_remote_code=True,
-                        enforce_eager=True,
-                        dtype="float16",
-                        tensor_parallel_size=tp_size)
-            return model, sampling_params
-        except Exception as e:
-            logger.error(f"vllm initial failed, {e}")
+        # vllm
+        from vllm import LLM, SamplingParams
+        sampling_params = SamplingParams(temperature=1,
+                                         top_p=0.95,
+                                         max_tokens=1024,
+                                         stop_token_ids=[tokenizer.eos_token_id])
+        model = LLM(model=model_path,
+                    trust_remote_code=True,
+                    enforce_eager=True,
+                    dtype="float16",
+                    tensor_parallel_size=tp_size)
+        return model, tokenizer, sampling_params
     else:
+        # huggingface
         model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
-        return model, tokenizer
+        return model, tokenizer, None


 def llm_inference(args):
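Taken together, the vLLM branch of init_model() boils down to the standalone sketch below. The model path and the single-turn message are placeholders, and the argument names simply mirror the calls in the hunk above for the vLLM version this repository appears to target:

    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams

    model_path = "path/to/model"   # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    sampling_params = SamplingParams(temperature=1, top_p=0.95, max_tokens=1024,
                                     stop_token_ids=[tokenizer.eos_token_id])
    llm = LLM(model=model_path, trust_remote_code=True, enforce_eager=True,
              dtype="float16", tensor_parallel_size=1)
    prompt_ids = tokenizer.apply_chat_template([{"role": "user", "content": "写一首诗"}],
                                               add_generation_prompt=True)
    outputs = llm.generate(prompt_token_ids=[prompt_ids], sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)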
@@ -208,21 +205,23 @@ def llm_inference(args):
     stream_chat = config.getboolean('llm', 'stream_chat')
     logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")

-    model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
-    inference = LLMInference(model,
-                             tokenzier,
-                             use_vllm=use_vllm,
-                             stream_chat=args.stream_chat)
+    model, tokenzier, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
+    llm_infer = LLMInference(model,
+                             tokenzier,
+                             sampling_params,
+                             use_vllm=use_vllm,
+                             stream_chat=stream_chat)

     async def inference(request):
         start = time.time()
         input_json = await request.json()
         prompt = input_json['query']
         history = input_json['history']
-        if args.stream_chat:
-            text = inference.stream_chat(prompt=prompt, history=history)
+        if stream_chat:
+            text = llm_infer.stream_chat(prompt=prompt, history=history)
         else:
-            text = inference.chat(prompt=prompt, history=history)
+            text = llm_infer.chat(prompt=prompt, history=history)
         end = time.time()
         logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, text, end - start))
         return web.json_response({'text': text})
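For reference, the handler above expects a JSON body with 'query' and 'history' and answers with {'text': ...}. A hedged client sketch follows; the host, port, and route are assumptions, since they are configured outside this hunk:

    import requests

    # Placeholder URL: the real bind address and route are defined elsewhere in this file.
    resp = requests.post("http://127.0.0.1:8888/inference",
                         json={"query": "写一首诗", "history": []})
    print(resp.json()["text"])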
@@ -243,13 +242,13 @@ def infer_test(args):
     logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")

     model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
-    inference = LLMInference(model,
+    llm_infer = LLMInference(model,
                              tokenzier,
                              use_vllm=use_vllm,
                              stream_chat=stream_chat)

     time_first = time.time()
-    output_text = inference.chat(args.query)
+    output_text = llm_infer.chat(args.query)
     time_second = time.time()
     logger.debug('问题:{} 回答:{} \ntimecost {} '.format(
         args.query, output_text, time_second - time_first))
@@ -274,12 +273,12 @@ def parse_args():
         help='config目录')
     parser.add_argument(
         '--query',
-        default=['2000e防火墙恢复密码和忘记IP查询操作'],
+        default=['写一首诗'],
         help='提问的问题.')
     parser.add_argument(
         '--DCU_ID',
         type=str,
-        default='0,1',
+        default='6',
         help='设置DCU卡号,卡号之间用英文逗号隔开,输入样例:"0,1,2"')
     args = parser.parse_args()
     return args
@@ -288,8 +287,8 @@ def parse_args():
 def main():
     args = parse_args()
     set_envs(args.DCU_ID)
-    # llm_inference(args)
-    infer_test(args)
+    llm_inference(args)
+    # infer_test(args)


 if __name__ == '__main__':