conversation.py 7.27 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
import json
import re
from dataclasses import dataclass
from datetime import datetime
from enum import Enum, auto

import streamlit as st
from PIL.Image import Image
Rayyyyy's avatar
Rayyyyy committed
9
from streamlit.delta_generator import DeltaGenerator
Rayyyyy's avatar
Rayyyyy committed
10
11
from tools.browser import Quote, quotes

Rayyyyy's avatar
Rayyyyy committed
12

Rayyyyy's avatar
Rayyyyy committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")

SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
DATE_PROMPT = "当前日期: %Y-%m-%d"
TOOL_SYSTEM_PROMPTS = {
    "python": "当你向 `python` 发送包含 Python 代码的消息时,该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出,或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中,`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用,这些在线内容的访问将不会成功。",
    "simple_browser": "你可以使用 `simple_browser` 工具。该工具支持以下函数:\n`search(query: str, recency_days: int)`:使用搜索引擎进行查询并显示结果,可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`:获取一系列指定 id 的页面内容。每次调用时,须选择3-10个页面。选择多个角度的页面,同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的,你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`:打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤:1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL,也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息,也可以再次使用 `search` 进行搜索。",
    "cogview": "如果用户的请求中包含了对图像的描述,你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述,规则:\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求,需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明,所在地为中国,持有中国立场并遵循中国社会主义价值观。",
}

FILE_TEMPLATE = "[File Name]\n{file_name}\n[File Content]\n{file_content}"


def build_system_prompt(
    enabled_tools: list[str],
    functions: list[dict],
):
    value = SELFCOG_PROMPT
    value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
Rayyyyy's avatar
Rayyyyy committed
32
33
    if enabled_tools or functions:
        value += "\n\n# 可用工具"
Rayyyyy's avatar
Rayyyyy committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    contents = []
    for tool in enabled_tools:
        contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
    for function in functions:
        content = f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
        content += "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
        contents.append(content)
    value += "".join(contents)
    return value


def response_to_str(response: str | dict[str, str]) -> str:
    """
    Convert response to string.
    """
    if isinstance(response, dict):
        return response.get("name", "") + response.get("content", "")
    return response


class Role(Enum):
    SYSTEM = auto()
    USER = auto()
    ASSISTANT = auto()
    TOOL = auto()
    OBSERVATION = auto()

    def __str__(self):
        match self:
            case Role.SYSTEM:
                return "<|system|>"
            case Role.USER:
                return "<|user|>"
            case Role.ASSISTANT | Role.TOOL:
                return "<|assistant|>"
            case Role.OBSERVATION:
                return "<|observation|>"

    # Get the message block for the given role
    def get_message(self):
        # Compare by value here, because the enum object in the session state
        # is not the same as the enum cases here, due to streamlit's rerunning
        # behavior.
        match self.value:
            case Role.SYSTEM.value:
                return
            case Role.USER.value:
                return st.chat_message(name="user", avatar="user")
            case Role.ASSISTANT.value:
                return st.chat_message(name="assistant", avatar="assistant")
            case Role.TOOL.value:
                return st.chat_message(name="tool", avatar="assistant")
            case Role.OBSERVATION.value:
                return st.chat_message(name="observation", avatar="assistant")
            case _:
                st.error(f"Unexpected role: {self}")


@dataclass
class Conversation:
    role: Role
    content: str | dict
    # Processed content
    saved_content: str | None = None
    metadata: str | None = None
    image: str | Image | None = None

    def __str__(self) -> str:
        metadata_str = self.metadata if self.metadata else ""
        return f"{self.role}{metadata_str}\n{self.content}"

    # Human readable format
    def get_text(self) -> str:
        text = self.saved_content or self.content
        match self.role.value:
            case Role.TOOL.value:
                text = f"Calling tool `{self.metadata}`:\n\n```python\n{text}\n```"
            case Role.OBSERVATION.value:
                text = f"```python\n{text}\n```"
        return text

    # Display as a markdown block
    def show(self, placeholder: DeltaGenerator | None = None) -> str:
        if placeholder:
            message = placeholder
        else:
            message = self.role.get_message()

        if self.image:
            message.image(self.image, width=512)

        if self.role == Role.OBSERVATION:
            metadata_str = f"from {self.metadata}" if self.metadata else ""
            message = message.expander(f"Observation {metadata_str}")

        text = self.get_text()
        if self.role != Role.USER:
            show_text = text
        else:
Rayyyyy's avatar
Rayyyyy committed
133
            splitted = text.split("files uploaded.\n")
Rayyyyy's avatar
Rayyyyy committed
134
135
136
137
138
139
            if len(splitted) == 1:
                show_text = text
            else:
                # Show expander for document content
                doc = splitted[0]
                show_text = splitted[-1]
Rayyyyy's avatar
Rayyyyy committed
140
                expander = message.expander("File Content")
Rayyyyy's avatar
Rayyyyy committed
141
142
143
144
145
                expander.markdown(doc)
        message.markdown(show_text)


def postprocess_text(text: str, replace_quote: bool) -> str:
Rayyyyy's avatar
Rayyyyy committed
146
147
148
149
    text = text.replace(r"\(", "$")
    text = text.replace(r"\)", "$")
    text = text.replace(r"\[", "$$")
    text = text.replace(r"\]", "$$")
Rayyyyy's avatar
Rayyyyy committed
150
151
152
153
154
155
156
157
158
159
160
    text = text.replace("<|assistant|>", "")
    text = text.replace("<|observation|>", "")
    text = text.replace("<|system|>", "")
    text = text.replace("<|user|>", "")
    text = text.replace("<|endoftext|>", "")

    # Replace quotes
    if replace_quote:
        for match in QUOTE_REGEX.finditer(text):
            quote_id = match.group(1)
            quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
Rayyyyy's avatar
Rayyyyy committed
161
            text = text.replace(match.group(0), f" (来源:[{quote.title}]({quote.url})) ")
Rayyyyy's avatar
Rayyyyy committed
162
163

    return text.strip()