Unverified commit e111d5b0, authored by Simon Mo, committed by GitHub
Browse files

[CLI] Use streaming in CLI chat and completion commands (#23769)


Signed-off-by: simon-mo <simon.mo@hey.com>
parent a904ea78
...@@ -45,6 +45,28 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]: ...@@ -45,6 +45,28 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
return model_name, openai_client return model_name, openai_client
def _print_chat_stream(stream) -> str:
output = ""
for chunk in stream:
delta = chunk.choices[0].delta
if delta.content:
output += delta.content
print(delta.content, end="", flush=True)
print()
return output
def _print_completion_stream(stream) -> str:
output = ""
for chunk in stream:
text = chunk.choices[0].text
if text is not None:
output += text
print(text, end="", flush=True)
print()
return output
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
conversation: list[ChatCompletionMessageParam] = [] conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None: if system_prompt is not None:
...@@ -58,14 +80,11 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: ...@@ -58,14 +80,11 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
break break
conversation.append({"role": "user", "content": input_message}) conversation.append({"role": "user", "content": input_message})
chat_completion = client.chat.completions.create(model=model_name, stream = client.chat.completions.create(model=model_name,
messages=conversation) messages=conversation,
stream=True)
response_message = chat_completion.choices[0].message output = _print_chat_stream(stream)
output = response_message.content conversation.append({"role": "assistant", "content": output})
conversation.append(response_message) # type: ignore
print(output)
def _add_query_options( def _add_query_options(
...@@ -108,9 +127,11 @@ class ChatCommand(CLISubcommand): ...@@ -108,9 +127,11 @@ class ChatCommand(CLISubcommand):
if args.quick: if args.quick:
conversation.append({"role": "user", "content": args.quick}) conversation.append({"role": "user", "content": args.quick})
chat_completion = client.chat.completions.create( stream = client.chat.completions.create(model=model_name,
model=model_name, messages=conversation) messages=conversation,
print(chat_completion.choices[0].message.content) stream=True)
output = _print_chat_stream(stream)
conversation.append({"role": "assistant", "content": output})
return return
print("Please enter a message for the chat model:") print("Please enter a message for the chat model:")
...@@ -121,14 +142,11 @@ class ChatCommand(CLISubcommand): ...@@ -121,14 +142,11 @@ class ChatCommand(CLISubcommand):
break break
conversation.append({"role": "user", "content": input_message}) conversation.append({"role": "user", "content": input_message})
chat_completion = client.chat.completions.create( stream = client.chat.completions.create(model=model_name,
model=model_name, messages=conversation) messages=conversation,
stream=True)
response_message = chat_completion.choices[0].message output = _print_chat_stream(stream)
output = response_message.content conversation.append({"role": "assistant", "content": output})
conversation.append(response_message) # type: ignore
print(output)
@staticmethod @staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
...@@ -168,9 +186,10 @@ class CompleteCommand(CLISubcommand): ...@@ -168,9 +186,10 @@ class CompleteCommand(CLISubcommand):
model_name, client = _interactive_cli(args) model_name, client = _interactive_cli(args)
if args.quick: if args.quick:
completion = client.completions.create(model=model_name, stream = client.completions.create(model=model_name,
prompt=args.quick) prompt=args.quick,
print(completion.choices[0].text) stream=True)
_print_completion_stream(stream)
return return
print("Please enter prompt to complete:") print("Please enter prompt to complete:")
...@@ -179,10 +198,10 @@ class CompleteCommand(CLISubcommand): ...@@ -179,10 +198,10 @@ class CompleteCommand(CLISubcommand):
input_prompt = input("> ") input_prompt = input("> ")
except EOFError: except EOFError:
break break
completion = client.completions.create(model=model_name, stream = client.completions.create(model=model_name,
prompt=input_prompt) prompt=input_prompt,
output = completion.choices[0].text stream=True)
print(output) _print_completion_stream(stream)
@staticmethod @staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment