cli_demo_chatglm3.py 1.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
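"""
ChatGLM3 chat demo that builds the prompt through the Python tokenizer.
Using it requires changing how hf_model.py assigns model.config.model_type
when converting a ChatGLM3 model; not recommended for external use.

使用python的分词构造指令方式实现chatglm3,需要修改hf_model.py中的chatglm3模型转换时model.config.model_type的赋值实现,不推荐外部使用
"""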
import argparse
from fastllm_pytools import llm
import time
from transformers import AutoTokenizer, AutoModel

def args_parser(argv = None):
    """Parse the command-line options of the chat demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior.

    Returns:
        argparse.Namespace whose ``path`` attribute (str) is the model path
        passed with ``-p``/``--path``.

    Raises:
        SystemExit: raised by argparse when ``-p/--path`` is missing or
            ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description = 'fastllm_chat_demo')
    # No ``default``: with ``required=True`` argparse never falls back to it.
    parser.add_argument('-p', '--path', type = str, required = True, help = '模型文件的路径')
    return parser.parse_args(argv)

def main():
    """Run the interactive ChatGLM3 command-line chat loop.

    Loads the Hugging Face model and tokenizer from ``--path`` and converts
    the model with ``llm.from_hf(..., dtype = "float16")``. It then reads
    user input in a loop, streams each reply to stdout and prints a
    throughput line after every answer.

    Commands: ``stop`` exits, ``clear`` resets the conversation history.
    End-of-input (Ctrl-D) or Ctrl-C at the prompt also exits cleanly.
    """
    args = args_parser()
    model_path = args.path
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    # NOTE: a model file already converted to fastllm format could instead be
    # loaded directly with `llm.model(args.path)`.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Rebinding `model` means this script keeps no reference to the HF model.
    model = llm.from_hf(model, tokenizer, dtype = "float16")

    banner = "输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    history = []
    print(banner)
    while True:
        try:
            query = input("\n用户:")
        except (EOFError, KeyboardInterrupt):
            # Treat Ctrl-D / Ctrl-C at the prompt like "stop" instead of
            # exiting with a traceback.
            print()
            break
        command = query.strip()
        if command == "stop":
            break
        if command == "clear":
            history = []
            print(banner)
            continue
        print("AI:", end = "")
        current_length = 0
        # Counts streamed updates; reported as tokens on the assumption that
        # stream_chat yields once per generated token (TODO confirm).
        token_count = 0
        # perf_counter is monotonic, unlike time.time(), so elapsed is never negative.
        t0 = time.perf_counter()
        for response, history in model.stream_chat(tokenizer, query, history=history):
            # `response` is the full text so far; print only the new suffix.
            print(response[current_length:], end="", flush=True)
            token_count += 1
            current_length = len(response)

        elapsed = time.perf_counter() - t0
        if elapsed > 0:
            print("\ntoken/s: {:.2f}, character/s: {:.2f}".format(token_count/elapsed, current_length/elapsed))
        else:
            print()


if __name__ == "__main__":
    main()