Unverified Commit 5a5e29de authored by Reid's avatar Reid Committed by GitHub
Browse files

[Misc] refactor examples series - Chat Completion Client With Tools (#16829)


Signed-off-by: default avatarreidliu41 <reid201711@gmail.com>
Co-authored-by: default avatarreidliu41 <reid201711@gmail.com>
parent 3d3ab368
...@@ -17,6 +17,7 @@ vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ ...@@ -17,6 +17,7 @@ vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
--enable-auto-tool-choice --tool-call-parser hermes --enable-auto-tool-choice --tool-call-parser hermes
""" """
import json import json
from typing import Any
from openai import OpenAI from openai import OpenAI
...@@ -24,15 +25,6 @@ from openai import OpenAI ...@@ -24,15 +25,6 @@ from openai import OpenAI
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
tools = [{ tools = [{
"type": "function", "type": "function",
"function": { "function": {
...@@ -78,86 +70,123 @@ messages = [{ ...@@ -78,86 +70,123 @@ messages = [{
"Can you tell me what the temperate will be in Dallas, in fahrenheit?" "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}] }]
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print("Chat completion results:")
print(chat_completion)
print("\n\n")
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=True)
chunks = []
for chunk in tool_calls_stream:
chunks.append(chunk)
if chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls[0])
else:
print(chunk.choices[0].delta)
arguments = []
tool_call_idx = -1
for chunk in chunks:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
if tool_call_idx >= 0:
print(
f"streamed tool call arguments: {arguments[tool_call_idx]}"
)
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
if tool_call.id:
print(f"streamed tool call id: {tool_call.id} ")
if tool_call.function:
if tool_call.function.name:
print(f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
if len(arguments):
print(f"streamed tool call arguments: {arguments[-1]}")
print("\n\n")
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'): def get_current_weather(city: str, state: str, unit: 'str'):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's.") "partly cloudly, with highs in the 90's.")
available_tools = {"get_current_weather": get_current_weather} def handle_tool_calls_stream(
client: OpenAI,
completion_tool_calls = chat_completion.choices[0].message.tool_calls messages: list[dict[str, str]],
for call in completion_tool_calls: model: str,
tool_to_call = available_tools[call.function.name] tools: list[dict[str, Any]],
args = json.loads(call.function.arguments) ) -> list[Any]:
result = tool_to_call(**args) tool_calls_stream = client.chat.completions.create(messages=messages,
print(result) model=model,
tools=tools,
stream=True)
chunks = []
print("chunks: ")
for chunk in tool_calls_stream:
chunks.append(chunk)
if chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls[0])
else:
print(chunk.choices[0].delta)
return chunks
def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
arguments = []
tool_call_idx = -1
print("arguments: ")
for chunk in chunks:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
if tool_call_idx >= 0:
print(f"streamed tool call arguments: "
f"{arguments[tool_call_idx]}")
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
if tool_call.id:
print(f"streamed tool call id: {tool_call.id} ")
if tool_call.function:
if tool_call.function.name:
print(
f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
return arguments
def main():
# Initialize OpenAI client
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
# Get available models and select one
models = client.models.list()
model = models.data[0].id
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print("-" * 70)
print("Chat completion results:")
print(chat_completion)
print("-" * 70)
# Stream tool calls
chunks = handle_tool_calls_stream(client, messages, model, tools)
print("-" * 70)
# Handle arguments from streamed tool calls
arguments = handle_tool_calls_arguments(chunks)
if len(arguments):
print(f"streamed tool call arguments: {arguments[-1]}\n")
print("-" * 70)
# Add tool call results to the conversation
messages.append({ messages.append({
"role": "tool", "role": "assistant",
"content": result, "tool_calls": chat_completion.choices[0].message.tool_calls
"tool_call_id": call.id,
"name": call.function.name
}) })
chat_completion_2 = client.chat.completions.create(messages=messages, # Now, simulate a tool call
model=model, available_tools = {"get_current_weather": get_current_weather}
tools=tools,
stream=False) completion_tool_calls = chat_completion.choices[0].message.tool_calls
print("\n\n") for call in completion_tool_calls:
print(chat_completion_2) tool_to_call = available_tools[call.function.name]
args = json.loads(call.function.arguments)
result = tool_to_call(**args)
print("tool_to_call result: ", result)
messages.append({
"role": "tool",
"content": result,
"tool_call_id": call.id,
"name": call.function.name
})
chat_completion_2 = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=False)
print("Chat completion2 results:")
print(chat_completion_2)
print("-" * 70)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment