"cmake/vscode:/vscode.git/clone" did not exist on "c7914d30f90bc47f1c959d3330666885a0034f7d"
Unverified Commit 5a5e29de authored by Reid's avatar Reid Committed by GitHub
Browse files

[Misc] refactor examples series - Chat Completion Client With Tools (#16829)


Signed-off-by: default avatarreidliu41 <reid201711@gmail.com>
Co-authored-by: default avatarreidliu41 <reid201711@gmail.com>
parent 3d3ab368
...@@ -17,6 +17,7 @@ vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ ...@@ -17,6 +17,7 @@ vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
--enable-auto-tool-choice --tool-call-parser hermes --enable-auto-tool-choice --tool-call-parser hermes
""" """
import json import json
from typing import Any
from openai import OpenAI from openai import OpenAI
...@@ -24,15 +25,6 @@ from openai import OpenAI ...@@ -24,15 +25,6 @@ from openai import OpenAI
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
tools = [{ tools = [{
"type": "function", "type": "function",
"function": { "function": {
...@@ -78,39 +70,44 @@ messages = [{ ...@@ -78,39 +70,44 @@ messages = [{
"Can you tell me what the temperate will be in Dallas, in fahrenheit?" "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}] }]
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print("Chat completion results:") def get_current_weather(city: str, state: str, unit: 'str'):
print(chat_completion) return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
print("\n\n") "partly cloudly, with highs in the 90's.")
tool_calls_stream = client.chat.completions.create(messages=messages,
def handle_tool_calls_stream(
client: OpenAI,
messages: list[dict[str, str]],
model: str,
tools: list[dict[str, Any]],
) -> list[Any]:
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model, model=model,
tools=tools, tools=tools,
stream=True) stream=True)
chunks = []
chunks = [] print("chunks: ")
for chunk in tool_calls_stream: for chunk in tool_calls_stream:
chunks.append(chunk) chunks.append(chunk)
if chunk.choices[0].delta.tool_calls: if chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls[0]) print(chunk.choices[0].delta.tool_calls[0])
else: else:
print(chunk.choices[0].delta) print(chunk.choices[0].delta)
return chunks
arguments = []
tool_call_idx = -1
for chunk in chunks:
def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
arguments = []
tool_call_idx = -1
print("arguments: ")
for chunk in chunks:
if chunk.choices[0].delta.tool_calls: if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0] tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx: if tool_call.index != tool_call_idx:
if tool_call_idx >= 0: if tool_call_idx >= 0:
print( print(f"streamed tool call arguments: "
f"streamed tool call arguments: {arguments[tool_call_idx]}" f"{arguments[tool_call_idx]}")
)
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("") arguments.append("")
if tool_call.id: if tool_call.id:
...@@ -118,36 +115,63 @@ for chunk in chunks: ...@@ -118,36 +115,63 @@ for chunk in chunks:
if tool_call.function: if tool_call.function:
if tool_call.function.name: if tool_call.function.name:
print(f"streamed tool call name: {tool_call.function.name}") print(
f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments: if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments arguments[tool_call_idx] += tool_call.function.arguments
if len(arguments): return arguments
print(f"streamed tool call arguments: {arguments[-1]}")
print("\n\n")
messages.append({ def main():
"role": "assistant", # Initialize OpenAI client
"tool_calls": chat_completion.choices[0].message.tool_calls client = OpenAI(
}) # defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
# Get available models and select one
models = client.models.list()
model = models.data[0].id
# Now, simulate a tool call chat_completion = client.chat.completions.create(messages=messages,
def get_current_weather(city: str, state: str, unit: 'str'): model=model,
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " tools=tools)
"partly cloudly, with highs in the 90's.")
print("-" * 70)
print("Chat completion results:")
print(chat_completion)
print("-" * 70)
# Stream tool calls
chunks = handle_tool_calls_stream(client, messages, model, tools)
print("-" * 70)
# Handle arguments from streamed tool calls
arguments = handle_tool_calls_arguments(chunks)
available_tools = {"get_current_weather": get_current_weather} if len(arguments):
print(f"streamed tool call arguments: {arguments[-1]}\n")
completion_tool_calls = chat_completion.choices[0].message.tool_calls print("-" * 70)
for call in completion_tool_calls:
# Add tool call results to the conversation
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
# Now, simulate a tool call
available_tools = {"get_current_weather": get_current_weather}
completion_tool_calls = chat_completion.choices[0].message.tool_calls
for call in completion_tool_calls:
tool_to_call = available_tools[call.function.name] tool_to_call = available_tools[call.function.name]
args = json.loads(call.function.arguments) args = json.loads(call.function.arguments)
result = tool_to_call(**args) result = tool_to_call(**args)
print(result) print("tool_to_call result: ", result)
messages.append({ messages.append({
"role": "tool", "role": "tool",
"content": result, "content": result,
...@@ -155,9 +179,14 @@ for call in completion_tool_calls: ...@@ -155,9 +179,14 @@ for call in completion_tool_calls:
"name": call.function.name "name": call.function.name
}) })
chat_completion_2 = client.chat.completions.create(messages=messages, chat_completion_2 = client.chat.completions.create(messages=messages,
model=model, model=model,
tools=tools, tools=tools,
stream=False) stream=False)
print("\n\n") print("Chat completion2 results:")
print(chat_completion_2) print(chat_completion_2)
print("-" * 70)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment