Commit 75ac58c8 authored by chenych

Modify stream chat

parent 4c5a8a74
@@ -216,7 +216,7 @@ def hf_inference(bind_port, model, tokenizer, stream_chat):
def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
    '''Start a web server that accepts HTTP requests and generates responses by calling the local LLM inference service.'''
    import uuid
    import json
    from typing import AsyncGenerator
    from fastapi.responses import StreamingResponse
@@ -241,22 +241,22 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
         async def stream_results() -> AsyncGenerator[bytes, None]:
             final_output = None
             async for request_output in results_generator:
-                final_output = request_output
+                # final_output = request_output
                 text_outputs = [output.text for output in request_output.outputs]
                 ret = {"text": text_outputs}
                 print(ret)
-                # yield (json.dumps(ret) + "\0").encode("utf-8")
+                yield (json.dumps(ret) + "\0").encode("utf-8")
                 # yield web.json_response({'text': text_outputs})
-            assert final_output is not None
-            return [output.text for output in final_output.outputs]
+            # assert final_output is not None
+            # return [output.text for output in final_output.outputs]
         if stream_chat:
             logger.info("****************** in chat stream *****************")
-            # return StreamingResponse(stream_results())
-            text = await stream_results()
-            output_text = substitution(text)
-            logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, output_text, time.time() - start))
-            return web.json_response({'text': output_text})
+            return StreamingResponse(stream_results())
+            # text = await stream_results()
+            # output_text = substitution(text)
+            # logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, output_text, time.time() - start))
+            # return web.json_response({'text': output_text})
         # Non-streaming case
         logger.info("****************** in chat ******************")
......
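With this change, the stream-chat branch no longer collects all generator output and returns a single JSON reply; it returns a FastAPI StreamingResponse whose body is a sequence of JSON objects, each terminated by a NUL byte ("\0"), one per decoding step. The sketch below shows one way a client could consume that stream. It is a minimal sketch under stated assumptions: the route path (/generate), port, and the {"prompt": ...} request payload are not shown in this diff and are hypothetical; the NUL delimiter and the {"text": [...]} message shape come from the code above.

import json
import requests

def stream_chat(prompt: str, host: str = "127.0.0.1", port: int = 8000) -> None:
    """Consume the NUL-delimited JSON stream produced by the streaming handler.

    The endpoint path and payload field names are assumptions for illustration.
    """
    url = f"http://{host}:{port}/generate"
    with requests.post(url, json={"prompt": prompt}, stream=True) as resp:
        resp.raise_for_status()
        buffer = b""
        for chunk in resp.iter_content(chunk_size=None):
            buffer += chunk
            # The server ends every message with "\0", matching
            # yield (json.dumps(ret) + "\0").encode("utf-8") above.
            while b"\0" in buffer:
                message, buffer = buffer.split(b"\0", 1)
                if message:
                    ret = json.loads(message)
                    # ret["text"] is a list with the text of each candidate output.
                    print(ret["text"])

if __name__ == "__main__":
    stream_chat("Hello")

Note that with vLLM's default behavior each streamed message carries the cumulative text generated so far (RequestOutput.outputs[i].text grows step by step), so a client that only wants the newly generated tokens has to diff consecutive messages itself.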