cli_demo_tool.py
"""
This demo script is designed for interacting with the ChatGLM3-6B in Function, to show Function Call capabilities.
"""

import os
import platform
from transformers import AutoTokenizer, AutoModel

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

# Load the tokenizer and model from Hugging Face (or a local path);
# device_map="auto" lets accelerate place the weights on the available devices.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False  # not actually used in this non-streaming demo


def build_prompt(history):
    """Rebuild a plain-text transcript of the conversation (not called by main())."""
    prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM3-6B:{response}"
    return prompt


# Tool schemas handed to the model via the system prompt. Each entry declares
# the tool name, a description, and a JSON-Schema-like parameter description.
tools = [
    {
        'name': 'track',
        'description': '追踪指定股票的实时价格',
        'parameters': {
            'type': 'object',
            'properties': {
                'symbol': {'description': '需要追踪的股票代码'}
            },
            'required': []
        }
    },
    {
        'name': '/text-to-speech',
        'description': '将文本转换为语音',
        'parameters': {
            'type': 'object',
            'properties': {
                'text': {'description': '需要转换成语音的文本'},
                'voice': {'description': '要使用的语音类型(男声、女声等)'},
                'speed': {'description': '语音的速度(快、中等、慢等)'}
            },
            'required': []
        }
    },
    {
        'name': '/image_resizer',
        'description': '调整图片的大小和尺寸',
        'parameters': {
            'type': 'object',
            'properties': {
                'image_file': {'description': '需要调整大小的图片文件'},
                'width': {'description': '需要调整的宽度值'},
                'height': {'description': '需要调整的高度值'}
            },
            'required': []
        }
    },
    {
        'name': '/foodimg',
        'description': '通过给定的食品名称生成该食品的图片',
        'parameters': {
            'type': 'object',
            'properties': {
                'food_name': {'description': '需要生成图片的食品名称'}
            },
            'required': []
        }
    }
]
# System message that exposes the tool schemas to the model.
system_item = {
    "role": "system",
    "content": "Answer the following questions as best as you can. You have access to the following tools:",
    "tools": tools
}
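# Illustrative example of one tool-call round trip under this system prompt
# (hypothetical values; actual model output varies):
#
#   用户: 帮我查一下股票 600519 的价格
#   -> model.chat returns {'name': 'track', 'parameters': {'symbol': '600519'}}
#
#   结果: {"price": 1718.5}        (tool output typed back by the user,
#                                   forwarded with role="observation")
#   -> model replies with a normal text answer based on the observation.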


def main():
    past_key_values, history = None, [system_item]
    role = "user"
    global stop_stream
    # Welcome banner: type to chat, "clear" resets the history, "stop" exits.
    print("欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        # After a tool call, the next input is the tool's result ("结果");
        # otherwise it is a normal user query ("用户").
        query = input("\n用户:") if role == "user" else input("\n结果:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            past_key_values, history = None, [system_item]
            role = "user"
            os.system(clear_command)
            print("欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        print("\nChatGLM:", end="")
        response, history = model.chat(tokenizer, query, history=history, role=role)
        print(response, end="", flush=True)
        print("")
        # A dict response means the model emitted a function call; the next
        # turn should carry the tool's output back with role "observation".
        if isinstance(response, dict):
            role = "observation"
        else:
            role = "user"


if __name__ == "__main__":
    main()
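# To try the demo (assuming the dependencies from the ChatGLM3 repository are
# installed), point MODEL_PATH at a local checkpoint or a Hugging Face model id
# and run the script, for example:
#
#   MODEL_PATH=THUDM/chatglm3-6b python cli_demo_tool.py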