"""
This script creates a CLI demo with a transformers backend for the THUDM/GLM-4-9B-0414 model,
allowing users to interact with the model through a command-line interface.

Usage:
- Run the script to start the CLI demo.
- Interact with the model by typing questions and receiving responses.

Note: Model output is streamed to the terminal as plain text as it is generated;
no markdown rendering is applied in this CLI interface.

If you want to use flash attention, install the flash-attn package and pass
attn_implementation="flash_attention_2" when loading the model.

"""

from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


MODEL_PATH = "THUDM/GLM-4-9B-0414"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
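
# Optional: as noted in the module docstring, flash attention can be enabled
# by installing the flash-attn package and passing
# attn_implementation="flash_attention_2" to from_pretrained. A sketch of the
# alternative loading call (assumes flash-attn is installed and the GPU
# supports it):
#
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_PATH,
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# ).eval()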


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recent token is an EOS token."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # model.config.eos_token_id may be a single int or a list of ids.
        stop_ids = model.config.eos_token_id
        if isinstance(stop_ids, int):
            stop_ids = [stop_ids]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


if __name__ == "__main__":
    history = []
    max_length = 8192  # passed to generate as max_new_tokens, i.e. the reply length cap
    top_p = 0.8
    temperature = 0.6
    stop = StopOnTokens()

    print("Welcome to the GLM-4-9B CLI chat. Type your messages below.")
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            break
        history.append([user_input, ""])

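        # Rebuild the chat-template message list from the stored history.
        # The newest entry has an empty assistant slot, so only its user turn
        # is appended before breaking out of the loop.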
        messages = []
        for idx, (user_msg, model_msg) in enumerate(history):
            if idx == len(history) - 1 and not model_msg:
                messages.append({"role": "user", "content": user_msg})
                break
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        model_inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
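        # The streamer yields decoded text pieces as generate produces them;
        # with timeout=60, iteration raises queue.Empty if no token arrives
        # within a minute.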
        streamer = TextIteratorStreamer(tokenizer=tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": max_length,
            "do_sample": True,
            "top_p": top_p,
            "temperature": temperature,
            "stopping_criteria": StoppingCriteriaList([stop]),
            "repetition_penalty": 1.2,
            "eos_token_id": model.config.eos_token_id,
        }
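        # model.generate blocks until generation completes, so it runs in a
        # background thread while the main thread consumes the streamer.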
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        print("GLM-4:", end="", flush=True)
        for new_token in streamer:
            if new_token:
                print(new_token, end="", flush=True)
                history[-1][1] += new_token
        t.join()  # make sure the generation thread has finished before continuing

        history[-1][1] = history[-1][1].strip()