Unverified Commit 5a5e29de authored by Reid's avatar Reid Committed by GitHub
Browse files

[Misc] refactor examples series - Chat Completion Client With Tools (#16829)


Signed-off-by: default avatarreidliu41 <reid201711@gmail.com>
Co-authored-by: default avatarreidliu41 <reid201711@gmail.com>
parent 3d3ab368
......@@ -17,6 +17,7 @@ vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
--enable-auto-tool-choice --tool-call-parser hermes
"""
import json
from typing import Any
from openai import OpenAI
......@@ -24,15 +25,6 @@ from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
tools = [{
"type": "function",
"function": {
......@@ -78,39 +70,44 @@ messages = [{
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print("Chat completion results:")
print(chat_completion)
print("\n\n")
def get_current_weather(city: str, state: str, unit: 'str'):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's.")
tool_calls_stream = client.chat.completions.create(messages=messages,
def handle_tool_calls_stream(
client: OpenAI,
messages: list[dict[str, str]],
model: str,
tools: list[dict[str, Any]],
) -> list[Any]:
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=True)
chunks = []
for chunk in tool_calls_stream:
chunks = []
print("chunks: ")
for chunk in tool_calls_stream:
chunks.append(chunk)
if chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls[0])
else:
print(chunk.choices[0].delta)
return chunks
arguments = []
tool_call_idx = -1
for chunk in chunks:
def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
arguments = []
tool_call_idx = -1
print("arguments: ")
for chunk in chunks:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
if tool_call_idx >= 0:
print(
f"streamed tool call arguments: {arguments[tool_call_idx]}"
)
print(f"streamed tool call arguments: "
f"{arguments[tool_call_idx]}")
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
if tool_call.id:
......@@ -118,36 +115,63 @@ for chunk in chunks:
if tool_call.function:
if tool_call.function.name:
print(f"streamed tool call name: {tool_call.function.name}")
print(
f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
if len(arguments):
print(f"streamed tool call arguments: {arguments[-1]}")
return arguments
print("\n\n")
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
def main():
# Initialize OpenAI client
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
# Get available models and select one
models = client.models.list()
model = models.data[0].id
# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's.")
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print("-" * 70)
print("Chat completion results:")
print(chat_completion)
print("-" * 70)
# Stream tool calls
chunks = handle_tool_calls_stream(client, messages, model, tools)
print("-" * 70)
# Handle arguments from streamed tool calls
arguments = handle_tool_calls_arguments(chunks)
available_tools = {"get_current_weather": get_current_weather}
if len(arguments):
print(f"streamed tool call arguments: {arguments[-1]}\n")
completion_tool_calls = chat_completion.choices[0].message.tool_calls
for call in completion_tool_calls:
print("-" * 70)
# Add tool call results to the conversation
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
# Now, simulate a tool call
available_tools = {"get_current_weather": get_current_weather}
completion_tool_calls = chat_completion.choices[0].message.tool_calls
for call in completion_tool_calls:
tool_to_call = available_tools[call.function.name]
args = json.loads(call.function.arguments)
result = tool_to_call(**args)
print(result)
print("tool_to_call result: ", result)
messages.append({
"role": "tool",
"content": result,
......@@ -155,9 +179,14 @@ for call in completion_tool_calls:
"name": call.function.name
})
chat_completion_2 = client.chat.completions.create(messages=messages,
chat_completion_2 = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=False)
print("\n\n")
print(chat_completion_2)
print("Chat completion2 results:")
print(chat_completion_2)
print("-" * 70)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment