"src/lib/vscode:/vscode.git/clone" did not exist on "9d58bb1c6657e7fc576eea7b34ca5e7e2eef8fd2"
Commit e08c9060 authored by chenych

Modify codes

parent f0863458
@@ -228,7 +228,6 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
# history = input_json['history']
messages = [{"role": "user", "content": prompt}]
logger.info("****************** use vllm ******************")
## generate template
input_text = tokenizer.apply_chat_template(
@@ -248,12 +247,15 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
print(ret)
# yield (json.dumps(ret) + "\0").encode("utf-8")
# yield web.json_response({'text': text_outputs})
return final_output
assert final_output is not None
return [output.text for output in final_output.outputs]
if stream_chat:
logger.info("****************** in chat stream *****************")
# return StreamingResponse(stream_results())
output_text = await stream_results()
text = await stream_results()
output_text = substitution(text)
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text})
# Non-streaming case
@@ -269,9 +271,10 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
assert final_output is not None
text = [output.text for output in final_output.outputs]
end = time.time()
output_text = substitution(text)
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, end - start))
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text})
app = web.Application()
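Both branches of vllm_inference above now run the generated text through substitution before building the JSON response, so placeholder tokens emitted by the model are swapped for the matching entries of the COMMON mapping shown further down. The body of substitution is not part of this diff; the sketch below is only one plausible implementation, assuming it receives the list of strings taken from final_output.outputs:

def substitution(outputs):
    # Sketch only: replace every COMMON placeholder found in the model output.
    texts = outputs if isinstance(outputs, list) else [outputs]
    cleaned = []
    for text in texts:
        for placeholder, value in COMMON.items():
            text = text.replace(placeholder, value)
        cleaned.append(text)
    # Collapse to a single string when there is exactly one completion.
    return cleaned[0] if len(cleaned) == 1 else cleaned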
@@ -13,8 +13,6 @@ from aiohttp import web
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from fastapi.responses import JSONResponse, Response, StreamingResponse
COMMON = {
"<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
@@ -113,24 +111,30 @@ def llm_inference(args):
messages, tokenize=False, add_generation_prompt=True)
print(text)
assert model is not None
request_id = str(uuid.uuid4().hex)
## vllm-0.5.0
# results_generator = model.generate(inputs=text, sampling_params=sampling_params, request_id=request_id)
## vllm-0.3.3
results_generator = model.generate(prompt=text, sampling_params=sampling_params, request_id=request_id)
results_generator = model.generate(text, sampling_params=sampling_params, request_id=request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
final_output = None
async for request_output in results_generator:
final_output = request_output
text_outputs = [output.text for output in request_output.outputs]
ret = {"text": text_outputs}
print(ret)
# yield (json.dumps(ret) + "\0").encode("utf-8")
yield web.json_response({'text': text})
# yield web.json_response({'text': text_outputs})
assert final_output is not None
return [output.text for output in final_output.outputs]
if stream_chat:
return StreamingResponse(stream_results())
logger.info("****************** in chat stream *****************")
# return StreamingResponse(stream_results())
text = await stream_results()
output_text = substitution(text)
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text})
# Non-streaming case
final_output = None
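The generate call drops the version-specific keyword (prompt= in vllm-0.3.3, inputs= in vllm-0.5.0) in favour of a positional first argument, and stream_results no longer yields partial chunks: it drains the async generator, keeps the last RequestOutput, and returns its texts so the caller can simply await it. A condensed sketch of that pattern, with a helper name invented for illustration and model assumed to be an AsyncLLMEngine:

import uuid
from typing import List
from vllm import AsyncLLMEngine, SamplingParams

async def generate_final_texts(model: AsyncLLMEngine, text: str,
                               sampling_params: SamplingParams) -> List[str]:
    # A positional prompt argument is accepted by both vllm-0.3.3 and vllm-0.5.0.
    results_generator = model.generate(text, sampling_params=sampling_params,
                                       request_id=uuid.uuid4().hex)
    final_output = None
    async for request_output in results_generator:
        # Each iteration carries the text generated so far; keep only the last one.
        final_output = request_output
    assert final_output is not None
    return [output.text for output in final_output.outputs]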
@@ -153,28 +157,6 @@ def llm_inference(args):
web.run_app(app, host='0.0.0.0', port=bind_port)
def infer_test(args):
config = configparser.ConfigParser()
config.read(args.config_path)
model_path = config['llm']['local_llm_path']
use_vllm = config.getboolean('llm', 'use_vllm')
tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
stream_chat = config.getboolean('llm', 'stream_chat')
logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")
model, tokenzier = init_model(model_path, use_vllm, tensor_parallel_size)
llm_infer = LLMInference(model,
tokenzier,
use_vllm=use_vllm)
time_first = time.time()
output_text = llm_infer.chat(args.query)
time_second = time.time()
logger.debug('Question: {} Answer: {} \ntimecost {} '.format(
args.query, output_text, time_second - time_first))
def set_envs(dcu_ids):
try:
os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
@@ -209,7 +191,6 @@ def main():
args = parse_args()
set_envs(args.DCU_ID)
llm_inference(args)
# infer_test(args)
if __name__ == '__main__':
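Since the StreamingResponse path is commented out, both the streaming and the non-streaming branch answer with a plain aiohttp JSON body of the form {'text': output_text}. A minimal client sketch under assumptions that are not visible in this diff (the /chat route, the prompt request field, and port 8080 are all hypothetical):

import asyncio
import aiohttp

async def ask(question: str, port: int = 8080):
    # Route "/chat" and the "prompt" field are assumptions, not taken from the diff.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"http://127.0.0.1:{port}/chat",
                                json={"prompt": question}) as resp:
            data = await resp.json()
            return data["text"]

if __name__ == "__main__":
    print(asyncio.run(ask("What can this assistant help with?")))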