Commit f0863458 authored by chenych

Fix stream chat

parent a92273ba
@@ -240,17 +240,21 @@ def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
        # Streaming case
        async def stream_results() -> AsyncGenerator[bytes, None]:
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
                text_outputs = [output.text for output in request_output.outputs]
                ret = {"text": text_outputs}
                print(ret)
                # yield (json.dumps(ret) + "\0").encode("utf-8")
                yield web.json_response({'text': text})
                # yield web.json_response({'text': text_outputs})
            return final_output
        if stream_chat:
            logger.info("****************** in chat stream *****************")
            return StreamingResponse(stream_results())
            # return StreamingResponse(stream_results())
            output_text = await stream_results()
            return web.json_response({'text': output_text})
        # Non-streaming case
        logger.info("****************** in chat ******************")
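For reference, the commented-out yield (json.dumps(ret) + "\0").encode("utf-8") and StreamingResponse(stream_results()) lines follow the pattern of vLLM's example API server. The sketch below shows how that streaming path could be wired end to end; it is only an illustration under assumptions (FastAPI/Starlette for StreamingResponse, an AsyncLLMEngine-backed results_generator, and hypothetical names such as engine, app, and the /generate request schema), not the code in this repository.

import json
from typing import AsyncGenerator

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

app = FastAPI()
# Assumed engine setup; the real script builds its model and tokenizer elsewhere.
engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))

@app.post("/generate")
async def generate(request: Request):
    request_dict = await request.json()          # assumed schema: {"prompt": ..., "stream": ...}
    prompt = request_dict.pop("prompt")
    stream_chat = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    results_generator = engine.generate(prompt, sampling_params, random_uuid())

    async def stream_results() -> AsyncGenerator[bytes, None]:
        # Yield one NUL-terminated JSON chunk per incremental output.
        # An async generator cannot `return` a value, so the final output
        # for the non-streaming case is collected outside this function.
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            yield (json.dumps({"text": text_outputs}) + "\0").encode("utf-8")

    # Streaming case: hand the async generator to StreamingResponse.
    if stream_chat:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and return only the last output.
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    text_outputs = [output.text for output in final_output.outputs]
    return JSONResponse({"text": text_outputs})

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)  # bind_port in the real script

Terminating each chunk with "\0" lets the client split the byte stream into JSON messages without a length prefix.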
@@ -344,7 +348,7 @@ def main():
    if use_vllm:
        vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
    else:
        hf_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
        hf_inference(bind_port, model, tokenizer, stream_chat)
    # infer_test(args)
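The second hunk drops sampling_params from the hf_inference call, which suggests the Hugging Face path derives its own generation settings. Purely as an illustration (none of this appears in the commit), a streaming helper for that path could be built on transformers' TextIteratorStreamer; the function name, defaults, and wiring below are hypothetical.

from threading import Thread

from transformers import TextIteratorStreamer

def hf_stream_generate(model, tokenizer, prompt, max_new_tokens=512):
    # Run generate() in a background thread and read decoded text chunks
    # from the streamer; hf_inference could forward these chunks when
    # stream_chat is set, or join them into one string otherwise.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens),
    )
    thread.start()
    for chunk in streamer:
        yield chunk
    thread.join()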