# demo_tool.py
import re
import yaml
from yaml import YAMLError

import streamlit as st
from streamlit.delta_generator import DeltaGenerator

from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools

# Passed as `max_length` to the client's streaming generation call.
MAX_LENGTH = 8192
# Tool observations longer than this many characters are cut and tagged
# with ' [TRUNCATED]' before being appended to the history.
TRUNCATE_LENGTH = 1024

# Default tool definition pre-filled (as YAML) in the manual-mode tools editor.
EXAMPLE_TOOL = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            },
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    },
}

# Model client shared by every rerun of the page; main() streams from it.
client = get_client()

def tool_call(*args, **kwargs) -> dict:
    """Capture a tool call emitted by the model and return its keyword arguments.

    main() evaluates the model's fenced code with this function bound to the
    name ``tool_call``, so an expression such as ``tool_call(location="Paris")``
    evaluates to ``{"location": "Paris"}``. Positional arguments are only
    printed; they are not part of the returned dict.

    Side effect: sets ``st.session_state.calling_tool = True``. main() reads
    this flag to decide whether the next prompt is recorded as the tool's
    observation instead of a new user turn.
    """
    print("=== Tool call:")
    print(args)
    print(kwargs)
    st.session_state.calling_tool = True
    return kwargs

def yaml_to_dict(tools: str) -> list | dict | None:
    """Parse the user's YAML tool definitions.

    Args:
        tools: YAML text describing the tools, normally a list of mappings
            shaped like ``EXAMPLE_TOOL``.

    Returns:
        The parsed list or mapping. Returns ``None`` when the text is not valid
        YAML, or when the document is a scalar (string, number, bool, null)
        that cannot describe any tools. Callers treat a falsy result as a
        format error.
    """
    try:
        parsed = yaml.safe_load(tools)
    except YAMLError:
        return None
    # safe_load returns whatever the document holds. Without this check a
    # plain word like "hello" would be truthy and sent to the model as tools.
    if isinstance(parsed, (list, dict)):
        return parsed
    return None

def extract_code(text: str) -> str:
    """Return the body of the last fenced (```) code block in *text*.

    The info string after the opening fence (e.g. ``python``) is ignored. The
    body is everything between the newline after the opening fence and the
    closing fence, usually including a trailing newline.

    Raises:
        ValueError: if *text* contains no complete fenced code block.
    """
    pattern = r'```([^\n]*)\n(.*?)```'
    matches = re.findall(pattern, text, re.DOTALL)
    if not matches:
        raise ValueError('no fenced code block found in text')
    return matches[-1][1]

def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder: DeltaGenerator | None = None,
) -> None:
    """Append *conversation* to *history* and render it.

    Args:
        conversation: the message to record and display.
        history: conversation list extended in place.
        placeholder: Streamlit element to render into. It is passed straight
            to ``Conversation.show``; ``None`` leaves the choice of target to
            that method.
    """
    history.append(conversation)
    conversation.show(placeholder)

60
def main(top_p: float, temperature: float, prompt_text: str, repetition_penalty: float):
61
    manual_mode = st.toggle('Manual mode',
62
63
        help='Define your tools in YAML format. You need to supply tool call results manually.'
    )
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

    if manual_mode:
        with st.expander('Tools'):
            tools = st.text_area(
                'Define your tools in YAML format here:',
                yaml.safe_dump([EXAMPLE_TOOL], sort_keys=False),
                height=400,
            )
        tools = yaml_to_dict(tools)

        if not tools:
            st.error('YAML format error in tools definition')
    else:
        tools = get_tools()

    if 'tool_history' not in st.session_state:
        st.session_state.tool_history = []
    if 'calling_tool' not in st.session_state:
        st.session_state.calling_tool = False

84
    history: list[Conversation] = st.session_state.tool_history
85
86
87
88
89
90
91
92
93
94

    for conversation in history:
        conversation.show()

    if prompt_text:
        prompt_text = prompt_text.strip()
        role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER
        append_conversation(Conversation(role, prompt_text), history)
        st.session_state.calling_tool = False

95
96
97
98
99
100
101
102
103
104
        input_text = preprocess_text(
            None,
            tools,
            history,
        )
        print("=== Input:")
        print(input_text)
        print("=== History:")
        print(history)

105
106
107
108
109
110
111
        placeholder = st.container()
        message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
        markdown_placeholder = message_placeholder.empty()

        for _ in range(5):
            output_text = ''
            for response in client.generate_stream(
112
113
114
115
116
117
118
119
120
                system=None,
                tools=tools,
                history=history,
                do_sample=True,
                max_length=MAX_LENGTH,
                temperature=temperature,
                top_p=top_p,
                stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
                repetition_penalty=repetition_penalty,
121
122
123
            ):
                token = response.token
                if response.token.special:
124
125
126
                    print("=== Output:")
                    print(output_text)

127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
                    match token.text.strip():
                        case '<|user|>':
                            append_conversation(Conversation(
                                Role.ASSISTANT,
                                postprocess_text(output_text),
                            ), history, markdown_placeholder)
                            return
                        # Initiate tool call
                        case '<|assistant|>':
                            append_conversation(Conversation(
                                Role.ASSISTANT,
                                postprocess_text(output_text),
                            ), history, markdown_placeholder)
                            output_text = ''
                            message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
                            markdown_placeholder = message_placeholder.empty()
                            continue
                        case '<|observation|>':
                            tool, *call_args_text = output_text.strip().split('\n')
                            call_args_text = '\n'.join(call_args_text)
147
                            
148
149
150
151
152
153
154
                            append_conversation(Conversation(
                                Role.TOOL,
                                postprocess_text(output_text),
                                tool,
                            ), history, markdown_placeholder)
                            message_placeholder = placeholder.chat_message(name="observation", avatar="user")
                            markdown_placeholder = message_placeholder.empty()
155
                            
156
157
158
159
160
161
                            try:
                                code = extract_code(call_args_text)
                                args = eval(code, {'tool_call': tool_call}, {})
                            except:
                                st.error('Failed to parse tool call')
                                return
162
                            
163
                            output_text = ''
164
                            
165
166
167
168
169
170
171
172
                            if manual_mode:
                                st.info('Please provide tool call results below:')
                                return
                            else:
                                with markdown_placeholder:
                                    with st.spinner(f'Calling tool {tool}...'):
                                        observation = dispatch_tool(tool, args)

173
174
                                if len(observation) > TRUNCATE_LENGTH:
                                    observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
                                append_conversation(Conversation(
                                    Role.OBSERVATION, observation
                                ), history, markdown_placeholder)
                                message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
                                markdown_placeholder = message_placeholder.empty()
                                st.session_state.calling_tool = False
                                break
                        case _:
                            st.error(f'Unexpected special token: {token.text.strip()}')
                            return
                output_text += response.token.text
                markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
            else:
                append_conversation(Conversation(
                    Role.ASSISTANT,
                    postprocess_text(output_text),
                ), history, markdown_placeholder)
                return