Commit 00f38043 authored by chenych

Change vllm stream chat
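Streaming now goes through its own GET endpoint, /vllm_inference_stream, which emits JSON chunks separated by "\0". A minimal client sketch, assuming the localhost:8888 bind and the 'query' payload field used by the scripts in this commit (the prompt string and Content-Type header below are illustrative):

import json
import requests

api_url = "http://localhost:8888/vllm_inference_stream"
payload = json.dumps({"query": "Hello, what can you do?"}).encode("utf-8")

# stream=True keeps the connection open so chunks arrive as they are generated
response = requests.get(api_url, headers={"Content-Type": "application/json"},
                        data=payload, verify=False, stream=True)

# The server yields json.dumps({"text": [...]}) + "\0" per step, so split on NUL
for chunk in response.iter_lines(chunk_size=8192, delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode("utf-8"))["text"])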

parent 75ac58c8
@@ -47,16 +47,18 @@ if __name__ == "__main__":
     config = configparser.ConfigParser()
     config.read(args.config_path)
     stream_chat = config.getboolean('llm', 'stream_chat')
     func = 'vllm_inference'
     if args.use_hf:
         func = 'hf_inference'
+    if stream_chat:
+        func = 'vllm_inference_stream'
     api_url = f"http://localhost:8888/{func}"
-    response = requests.post(api_url, headers=headers, data=json_str.encode(
-        "utf-8"), verify=False, stream=stream_chat)
     if stream_chat:
+        response = requests.get(api_url, headers=headers, data=json_str.encode(
+            "utf-8"), verify=False, stream=stream_chat)
         num_printed_lines = 0
         for h in get_streaming_response(response):
             clear_line(num_printed_lines)
@@ -65,6 +67,8 @@ if __name__ == "__main__":
                 num_printed_lines += 1
                 print(f"Beam candidate {i}: {line!r}", flush=True)
     else:
+        response = requests.get(api_url, headers=headers, data=json_str.encode(
+            "utf-8"), verify=False, stream=stream_chat)
         output = get_response(response)
         for i, line in enumerate(output):
             print(f"Beam candidate {i}: {line!r}", flush=True)
@@ -2,9 +2,9 @@ import time
 import os
 import configparser
 import argparse
-# import torch
+import json
 import asyncio
+import uuid
 from loguru import logger
 from aiohttp import web
 # from multiprocessing import Value
@@ -213,12 +213,8 @@ def hf_inference(bind_port, model, tokenizer, stream_chat):
     web.run_app(app, host='0.0.0.0', port=bind_port)
 
-def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
+def vllm_inference(bind_port, model, tokenizer, sampling_params):
     '''Start a web server that receives HTTP requests and generates responses by calling the local LLM inference service.'''
-    import uuid
-    import json
-    from typing import AsyncGenerator
-    from fastapi.responses import StreamingResponse
 
     async def inference(request):
         start = time.time()
@@ -237,27 +233,6 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
         request_id = str(uuid.uuid4().hex)
         results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)
 
-        # Streaming case
-        async def stream_results() -> AsyncGenerator[bytes, None]:
-            final_output = None
-            async for request_output in results_generator:
-                # final_output = request_output
-                text_outputs = [output.text for output in request_output.outputs]
-                ret = {"text": text_outputs}
-                print(ret)
-                yield (json.dumps(ret) + "\0").encode("utf-8")
-                # yield web.json_response({'text': text_outputs})
-            # assert final_output is not None
-            # return [output.text for output in final_output.outputs]
-
-        if stream_chat:
-            logger.info("****************** in chat stream *****************")
-            return StreamingResponse(stream_results())
-            # text = await stream_results()
-            # output_text = substitution(text)
-            # logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, output_text, time.time() - start))
-            # return web.json_response({'text': output_text})
-
         # Non-streaming case
         logger.info("****************** in chat ******************")
         final_output = None
@@ -282,6 +257,46 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
     web.run_app(app, host='0.0.0.0', port=bind_port)
 
+
+def vllm_inference_stream(bind_port, model, tokenizer, sampling_params):
+    '''Start a web server that receives HTTP requests and generates responses by calling the local LLM inference service.'''
+    from typing import AsyncGenerator
+    from fastapi.responses import StreamingResponse
+
+    async def inference(request):
+        input_json = await request.json()
+        prompt = input_json['query']
+        # history = input_json['history']
+        messages = [{"role": "user", "content": prompt}]
+
+        logger.info("****************** use vllm ******************")
+        ## generate template
+        input_text = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)
+        logger.info(f"The input_text is {input_text}")
+
+        assert model is not None
+        request_id = str(uuid.uuid4().hex)
+        results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)
+
+        # Streaming case
+        async def stream_results() -> AsyncGenerator[bytes, None]:
+            # final_output = None
+            logger.info("****************** in stream_results *****************")
+            async for request_output in results_generator:
+                # final_output = request_output
+                text_outputs = [output.text for output in request_output.outputs]
+                ret = {"text": text_outputs}
+                print(ret)
+                yield (json.dumps(ret) + "\0").encode("utf-8")
+
+        logger.info("****************** in chat stream *****************")
+        return StreamingResponse(stream_results())
+
+    app = web.Application()
+    app.add_routes([web.get('/vllm_inference_stream', inference)])
+    web.run_app(app, host='0.0.0.0', port=bind_port)
+
+
 def infer_test(args):
     config = configparser.ConfigParser()
     config.read(args.config_path)
@@ -349,7 +364,10 @@ def main():
     model, tokenizer, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
 
     if use_vllm:
-        vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
+        if stream_chat:
+            vllm_inference_stream(bind_port, model, tokenizer, sampling_params)
+        else:
+            vllm_inference(bind_port, model, tokenizer, sampling_params)
     else:
         hf_inference(bind_port, model, tokenizer, stream_chat)
     # infer_test(args)
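Side note on the new handler: vllm_inference_stream registers inference with aiohttp (web.get / web.run_app) but returns fastapi.responses.StreamingResponse, whereas aiohttp expects handlers to return its own web.Response / web.StreamResponse objects. A rough aiohttp-native sketch of the same streaming handler (the factory name and wiring below are illustrative, not part of the commit):

import json
import uuid
from aiohttp import web

def make_stream_handler(model, tokenizer, sampling_params):
    # Hypothetical factory mirroring vllm_inference_stream's handler, but using
    # aiohttp's own StreamResponse and the same "\0"-delimited JSON framing.
    async def inference(request: web.Request) -> web.StreamResponse:
        input_json = await request.json()
        messages = [{"role": "user", "content": input_json['query']}]
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        results_generator = model.generate(
            input_text, sampling_params=sampling_params,
            request_id=str(uuid.uuid4().hex))

        resp = web.StreamResponse()
        await resp.prepare(request)  # send headers, start the chunked body
        async for request_output in results_generator:
            ret = {"text": [o.text for o in request_output.outputs]}
            await resp.write((json.dumps(ret) + "\0").encode("utf-8"))
        await resp.write_eof()
        return resp

    return inference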