chat_with_tools.py 4.42 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

zhuwenwen's avatar
zhuwenwen committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# ruff: noqa
import json
import random
import string

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for function calling
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
#     "model": "mistralai/Mistral-7B-Instruct-v0.3",
#     "messages": [
#       {
#         "role": "user",
#         "content": [
#             {"type" : "text", "text": "Describe this image in detail please."},
#             {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
#             {"type" : "text", "text": "and this one as well. Answer in French."},
#             {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
#         ]
#       }
#     ]
#   }'
# ```
#
# Usage:
#     python chat_with_tools.py

# Checkpoint to load. Any Mistral model with function-calling ability can be
# substituted here, for example:
#   "mistralai/Mistral-Nemo-Instruct-2407"
#   "mistralai/Mistral-Large-Instruct-2407"
model_name = "mistralai/Mistral-7B-Instruct-v0.3"

# Shared by both chat rounds below. temperature=0.0 requests greedy decoding
# (see the vLLM SamplingParams documentation).
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)

# Tokenizer, config and weights are all read in the "mistral" format, matching
# the flags of the `vllm serve` command shown in the header comment.
llm = LLM(
    model=model_name,
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)
zhuwenwen's avatar
zhuwenwen committed
60
61
62
63


def generate_random_id(length=9):
    """Return a random alphanumeric string of ``length`` characters.

    The script uses the result as the ``tool_call_id`` of the tool message it
    sends back to the model.

    Args:
        length: Number of characters to generate. A value of 0 (or less)
            yields the empty string.
            NOTE(review): the default of 9 presumably matches the id length
            the Mistral tokenizer expects -- confirm before changing it.

    Returns:
        A string drawn from ASCII letters and digits.
    """
    characters = string.ascii_letters + string.digits
    random_id = "".join(random.choice(characters) for _ in range(length))
    return random_id


# simulate an API that can be called
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def get_current_weather(city: str, state: str, unit: "str"):
    """Stand-in for a real weather API: return a fixed report.

    Only the location and unit are interpolated; the temperature and the
    outlook are hard-coded, so every call yields the same forecast.

    Args:
        city: City name inserted into the report.
        state: State abbreviation inserted after the city.
        unit: Temperature unit inserted after the fixed value of 85.

    Returns:
        The report as a single string.
    """
    headline = f"The weather in {city}, {state} is 85 degrees {unit}. It is "
    outlook = "partly cloudly, with highs in the 90's."
    return headline + outlook


# Dispatch table: tool name as it appears in the model's output -> the local
# Python callable that implements it.
tool_functions = dict(get_current_weather=get_current_weather)

# Tool schema passed to the model via `llm.chat(..., tools=tools)`. It describes
# a single function, `get_current_weather`, whose three string arguments are
# all required; `unit` is restricted to the two values listed in its enum.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "the two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
]
zhuwenwen's avatar
zhuwenwen committed
107

108
109
110
111
112
113
# Conversation history. It starts with the user's question and is extended
# further down with the assistant's tool call and the tool's answer.
messages = [
    {
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, in fahrenheit?",
    }
]
zhuwenwen's avatar
zhuwenwen committed
114
115
116
117
118

# Round 1: send the question together with the tool schema. The reply text is
# parsed as JSON further down, so it is expected to be a tool call.
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
first_completion = outputs[0].outputs[0]
output = first_completion.text.strip()

# append the assistant message
119
120
121
122
123
124
# Keep the model's raw reply in the history so the follow-up request contains
# the assistant turn that precedes the tool result.
assistant_turn = {"role": "assistant", "content": output}
messages.append(assistant_turn)
zhuwenwen's avatar
zhuwenwen committed
125
126
127
128
129

# let's now actually parse and execute the model's output simulating an API call by using the
# above defined function
# The reply is decoded as a JSON list of calls, each carrying a "name" and an
# "arguments" mapping; every call is dispatched to the matching local function
# with its arguments passed as keyword arguments.
tool_calls = json.loads(output)
tool_answers = [
    tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
]

# append the answer as a tool message and let the LLM give you an answer
134
135
136
137
138
139
140
# All tool answers are merged into one tool message, separated by blank lines.
# The id is freshly generated here rather than taken from the model's output.
tool_turn = {
    "role": "tool",
    "content": "\n\n".join(tool_answers),
    "tool_call_id": generate_random_id(),
}
messages.append(tool_turn)
zhuwenwen's avatar
zhuwenwen committed
141
142
143
144
145
146
147

# Round 2: the history now holds the tool result, so ask the model again and
# print its final answer.
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
final_answer = outputs[0].outputs[0].text.strip()
print(final_answer)
# yields
#   'The weather in Dallas, TX is 85 degrees fahrenheit. '
#   'It is partly cloudly, with highs in the 90's.'