"vscode:/vscode.git/clone" did not exist on "e39406212b92490099cce6fb5fd7ce2da555bda7"
function_calling.py 3.79 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# encoding: utf-8
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from minicpm_tool_parser import fc2dict
import json

# Model identifier handed to both the tokenizer loader and the vLLM engine.
model_path = "openbmb/MiniCPM3-4B"

# JSON-schema style description of the arguments accepted by the one tool.
_delivery_date_params = {
    "type": "object",
    "properties": {
        "order_id": {"type": "string", "description": "The customer's order ID."},
    },
    "required": ["order_id"],
    "additionalProperties": False,
}

# Tool catalogue passed to the chat template via its `tools` argument.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_delivery_date",
            "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
            "parameters": _delivery_date_params,
        },
    }
]

# Opening turns of the conversation, written as (role, content) pairs.
# The generation loop appends assistant and tool turns to this list.
messages = [
    {"role": role, "content": content}
    for role, content in (
        (
            "system",
            "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
        ),
        (
            "user",
            "Hi, can you tell me the delivery date for my order? The order id is 1234 and 4321.",
        ),
    )
]
# For reference, a completed exchange continues with entries shaped like:
#   {"role": "assistant", "content": "", "tool_calls": [
#       {"type": "function", "id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
#        "function": {"name": "get_delivery_date",
#                     "arguments": {"order_id": "1234"}}},
#       {"type": "function", "id": "call_628965479dd84794bbb72ab9bdda0c39",
#        "function": {"name": "get_delivery_date",
#                     "arguments": {"order_id": "4321"}}}]}
#   {"role": "tool", "tool_call_id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
#    "content": '{"delivery_date": "2024-09-05", "order_id": "1234"}'}
#   {"role": "tool", "tool_call_id": "call_628965479dd84794bbb72ab9bdda0c39",
#    "content": '{"delivery_date": "2024-09-05", "order_id": "4321"}'}
#   {"role": "assistant",
#    "content": "Both your orders will be delivered on 2024-09-05.",
#    "thought": "\nI have the information you need, both orders will be "
#               "delivered on the same date, 2024-09-05.\n"}
# Tokenizer for the checkpoint. Its chat template is applied to `messages`
# and `tools` at the top of every round of the generation loop, so no prompt
# is rendered here (an eager render would be overwritten before any use).
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# vLLM engine for the same checkpoint.
llm = LLM(model_path, trust_remote_code=True)

# Sampling settings shared by every generation call: temperature 0.8,
# nucleus cutoff 0.95, and at most 1000 new tokens per reply.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1000)


def fake_tool_execute(toolcall):
    """Return a canned delivery-date result for one tool call as JSON text.

    No real lookup is performed: the delivery date is always "2024-09-05"
    and the order id is echoed back from the call's arguments.

    Args:
        toolcall: Mapping describing the call. The order id is read from
            ``toolcall["function"]["arguments"]["order_id"]``. ``arguments``
            may be a mapping or a JSON-encoded object string; a missing or
            ``None`` ``function`` / ``arguments`` entry is tolerated.

    Returns:
        A JSON string with ``delivery_date`` and ``order_id`` keys.
        ``order_id`` falls back to the literal ``"order_id"`` when the call
        does not carry one.
    """
    # `or {}` also covers keys that are present but explicitly None.
    function = toolcall.get("function") or {}
    arguments = function.get("arguments") or {}

    # Some tool-call formats carry the arguments as a JSON-encoded string.
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except ValueError:
            arguments = {}
        if not isinstance(arguments, dict):
            # Valid JSON that is not an object carries no order id.
            arguments = {}

    data = {
        "delivery_date": "2024-09-05",
        "order_id": arguments.get("order_id", "order_id"),
    }
    return json.dumps(data)


# Conversation loop: generate a reply, run any tools it requests, and repeat
# until the model produces a reply that contains no tool calls.
while True:
    # Render the whole history each round so appended tool results are seen.
    prompt = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )
    generation = llm.generate([prompt], sampling_params)[0]
    reply = fc2dict(generation.outputs[0].text)

    # The assistant turn is recorded and echoed whether or not it calls tools.
    messages.append(reply)
    print(reply)

    requested_calls = reply.get("tool_calls")
    if not requested_calls:
        # No tool calls (missing, None, or empty): the exchange is finished.
        break

    # Answer every requested call with its own "tool" turn, in order.
    for call in requested_calls:
        tool_turn = {
            "role": "tool",
            "content": fake_tool_execute(call),
            "tool_call_id": call["id"],
        }
        messages.append(tool_turn)
        print(tool_turn)