"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "a3a832a9359bc4a51f5fa5498d960b5ce073385d"
Commit e08c9060 authored by chenych

Modify codes

parent f0863458
@@ -228,7 +228,6 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
        # history = input_json['history']
        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use vllm ******************")
        ## generate template
        input_text = tokenizer.apply_chat_template(
@@ -248,12 +247,15 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
                print(ret)
                # yield (json.dumps(ret) + "\0").encode("utf-8")
                # yield web.json_response({'text': text_outputs})
-           return final_output
+           assert final_output is not None
+           return [output.text for output in final_output.outputs]
        if stream_chat:
            logger.info("****************** in chat stream *****************")
            # return StreamingResponse(stream_results())
-           output_text = await stream_results()
+           text = await stream_results()
+           output_text = substitution(text)
+           logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
            return web.json_response({'text': output_text})
        # Non-streaming case
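With every `yield` in `stream_results` commented out, the function is no longer an async generator but an ordinary coroutine: it drains `results_generator`, keeps the last `RequestOutput`, and returns its texts, which is why the streaming branch can now simply `await` it and wrap the result in a normal JSON response. Below is a minimal, self-contained sketch of that collect-then-respond pattern, with a stand-in async generator in place of the vLLM `results_generator` (all names are illustrative, not taken from the commit):

import asyncio

async def fake_results_generator():
    # Stand-in for vLLM's results_generator: each item is the text generated so far.
    for partial in ["Hel", "Hello", "Hello, world"]:
        await asyncio.sleep(0)            # pretend tokens arrive incrementally
        yield partial

async def collect_final_text():
    # Mirrors the edited stream_results(): keep only the last output.
    final_output = None
    async for request_output in fake_results_generator():
        final_output = request_output     # overwritten until the stream is exhausted
    assert final_output is not None
    return [final_output]                 # the real code returns [o.text for o in final_output.outputs]

print(asyncio.run(collect_final_text()))  # ['Hello, world']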
@@ -269,9 +271,10 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
        assert final_output is not None
        text = [output.text for output in final_output.outputs]
-       end = time.time()
        output_text = substitution(text)
-       logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, text, end - start))
+       logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
        return web.json_response({'text': output_text})

    app = web.Application()
...
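This hunk ends where `app = web.Application()` is created, and a later hunk runs it with `web.run_app(app, host='0.0.0.0', port=bind_port)`; the route registration itself is outside the diff. A plausible sketch of that wiring under stated assumptions (the '/chat' path, the `handle` name, and the request schema with a 'prompt' field are guesses based on the handler body, not taken from the commit):

from aiohttp import web

async def handle(request: web.Request) -> web.Response:
    input_json = await request.json()            # assumed JSON body: {"prompt": "...", "history": [...]}
    prompt = input_json['prompt']
    output_text = f"echo: {prompt}"              # placeholder for the vLLM generation + substitution path
    return web.json_response({'text': output_text})

app = web.Application()
app.add_routes([web.post('/chat', handle)])      # '/chat' is an assumed route name
# web.run_app(app, host='0.0.0.0', port=8000)    # mirrors the run_app call shown in a later hunk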
@@ -13,8 +13,6 @@ from aiohttp import web
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
-from fastapi.responses import JSONResponse, Response, StreamingResponse

COMMON = {
    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
@@ -113,24 +111,30 @@ def llm_inference(args):
            messages, tokenize=False, add_generation_prompt=True)
        print(text)
        assert model is not None
        request_id = str(uuid.uuid4().hex)
-       ## vllm-0.5.0
-       # results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
-       ## vllm-0.3.3
-       results_generator = model.generate(prompt=text, sampling_params=sampling_params, request_id=request_id)
+       results_generator = model.generate(text, sampling_params=sampling_params, request_id=request_id)

        # Streaming case
        async def stream_results() -> AsyncGenerator[bytes, None]:
+           final_output = None
            async for request_output in results_generator:
+               final_output = request_output
                text_outputs = [output.text for output in request_output.outputs]
                ret = {"text": text_outputs}
                print(ret)
                # yield (json.dumps(ret) + "\0").encode("utf-8")
-               yield web.json_response({'text': text})
+               # yield web.json_response({'text': text_outputs})
+           assert final_output is not None
+           return [output.text for output in final_output.outputs]

        if stream_chat:
-           return StreamingResponse(stream_results())
+           logger.info("****************** in chat stream *****************")
+           # return StreamingResponse(stream_results())
+           text = await stream_results()
+           output_text = substitution(text)
+           logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
+           return web.json_response({'text': output_text})

        # Non-streaming case
        final_output = None
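The removed comments record an API rename: `AsyncLLMEngine.generate` took the prompt as `prompt=` in vllm 0.3.3 but as `inputs=` in 0.5.0, and the replacement line sidesteps the rename by passing the text positionally. A small sketch of the same idea (`start_generation` is an invented helper name, not part of the project):

import uuid

def start_generation(engine, text, sampling_params):
    # Pass the prompt positionally so the same call works whether the first
    # parameter is named `prompt` (vllm 0.3.x) or `inputs` (vllm 0.5.x),
    # which is exactly what the edited line in this hunk does.
    request_id = str(uuid.uuid4().hex)
    return engine.generate(text, sampling_params=sampling_params, request_id=request_id)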
@@ -153,28 +157,6 @@ def llm_inference(args):
    web.run_app(app, host='0.0.0.0', port=bind_port)

-def infer_test(args):
-    config = configparser.ConfigParser()
-    config.read(args.config_path)
-    model_path = config['llm']['local_llm_path']
-    use_vllm = config.getboolean('llm', 'use_vllm')
-    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
-    stream_chat = config.getboolean('llm', 'stream_chat')
-    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
-    model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
-    llm_infer = LLMInference(model,
-                             tokenzier,
-                             use_vllm=use_vllm)
-    time_first = time.time()
-    output_text = llm_infer.chat(args.query)
-    time_second = time.time()
-    logger.debug('问题:{} 回答:{} \ntimecost {} '.format(
-        args.query, output_text, time_second - time_first))

def set_envs(dcu_ids):
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
@@ -209,7 +191,6 @@ def main():
    args = parse_args()
    set_envs(args.DCU_ID)
    llm_inference(args)
-   # infer_test(args)

if __name__ == '__main__':
...
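After `web.run_app(app, host='0.0.0.0', port=bind_port)` brings the server up, a client only needs to POST a JSON body containing a `prompt` and read the `text` field of the JSON reply. A hedged usage example with `requests` (the '/chat' path and port 8000 are assumptions; the real values come from the route registration and the config's bind_port, neither of which appears in this diff):

import requests

resp = requests.post(
    "http://127.0.0.1:8000/chat",                   # assumed path and port
    json={"prompt": "Please introduce yourself."},  # the handler reads the 'prompt' field
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["text"])                          # server replies with web.json_response({'text': output_text})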