cli_demo.py 4.3 KB
Newer Older
MPU王荣胜's avatar
add cli  
MPU王荣胜 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- encoding: utf-8 -*-

import os
import sys
import torch
import argparse
from transformers import AutoTokenizer
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize

from model import VisualGLMModel, chat
from finetune_XrayGLM import FineTuneVisualGLMModel
from sat.model import AutoModel


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt')
    parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
    parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
    args = parser.parse_args()

    # load model
    model, model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
        fp16=True,
        skip_init=True,
        use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
        device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
    ))
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)
    
MPU王荣胜's avatar
MPU王荣胜 committed
43
44
    #if torch.cuda.is_available():
    #    model = model.cuda()
MPU王荣胜's avatar
add cli  
MPU王荣胜 committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    if not args.english:
        print('欢迎使用 XrayGLM 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
    else:
        print('Welcome to XrayGLM model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
    with torch.no_grad():
        while True:
            history = None
            cache_image = None
            if not args.english:
                image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")
            else:
                image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ")

            if image_path == 'stop':
                break
            if len(image_path) > 0:
                query = args.prompt_en if args.english else args.prompt_zh
            else:
                if not args.english:
                    query = input("用户:")
                else:
                    query = input("User: ")
            while True:
                if query == "clear":
                    break
                if query == "stop":
                    sys.exit(0)
                try:
                    response, history, cache_image = chat(
                        image_path, 
                        model, 
                        tokenizer,
                        query, 
                        history=history, 
                        image=cache_image, 
                        max_length=args.max_length, 
                        top_p=args.top_p, 
                        temperature=args.temperature,
                        top_k=args.top_k,
                        english=args.english,
                        invalid_slices=[slice(63823, 130000)] if args.english else []
                        )
                except Exception as e:
                    print(e)
                    break
                sep = 'A:' if args.english else '答:'
                print("XrayGLM:"+response.split(sep)[-1].strip())
                image_path = None
                if not args.english:
                    query = input("用户:")
                else:
                    query = input("User: ")


if __name__ == "__main__":
MPU王荣胜's avatar
MPU王荣胜 committed
104
    main()