Commit 75ac58c8 authored by chenych's avatar chenych
Browse files

Modify stream chat

parent 4c5a8a74
...@@ -216,7 +216,7 @@ def hf_inference(bind_port, model, tokenizer, stream_chat): ...@@ -216,7 +216,7 @@ def hf_inference(bind_port, model, tokenizer, stream_chat):
def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat): def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
'''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. ''' '''启动 Web 服务器,接收 HTTP 请求,并通过调用本地的 LLM 推理服务生成响应. '''
import uuid import uuid
import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
...@@ -241,22 +241,22 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat): ...@@ -241,22 +241,22 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
async def stream_results() -> AsyncGenerator[bytes, None]: async def stream_results() -> AsyncGenerator[bytes, None]:
final_output = None final_output = None
async for request_output in results_generator: async for request_output in results_generator:
final_output = request_output # final_output = request_output
text_outputs = [output.text for output in request_output.outputs] text_outputs = [output.text for output in request_output.outputs]
ret = {"text": text_outputs} ret = {"text": text_outputs}
print(ret) print(ret)
# yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
# yield web.json_response({'text': text_outputs}) # yield web.json_response({'text': text_outputs})
assert final_output is not None # assert final_output is not None
return [output.text for output in final_output.outputs] # return [output.text for output in final_output.outputs]
if stream_chat: if stream_chat:
logger.info("****************** in chat stream *****************") logger.info("****************** in chat stream *****************")
# return StreamingResponse(stream_results()) return StreamingResponse(stream_results())
text = await stream_results() # text = await stream_results()
output_text = substitution(text) # output_text = substitution(text)
logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start)) # logger.debug('问题:{} 回答:{} \ntimecost {} '.format(prompt, output_text, time.time() - start))
return web.json_response({'text': output_text}) # return web.json_response({'text': output_text})
# Non-streaming case # Non-streaming case
logger.info("****************** in chat ******************") logger.info("****************** in chat ******************")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment