"applications/Chat/coati/vscode:/vscode.git/clone" did not exist on "f447ca18111c2e37a2f14e7aecc98876dc7e3216"
stream_chat_example.py 2.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import argparse

from transformers import AutoTokenizer, AutoModelForCausalLM
from colossal_llama2.utils.stream_chat_patch import streaming_chat

# System prompt seeded into `history` (under the empty role) before the chat loop starts.
SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."

def main(args):
    """Run an interactive streaming chat session on the console.

    Loads the chat model and tokenizer from ``args.model_path`` /
    ``args.tokenizer_path``, then loops reading user input and streaming
    the assistant's reply token-by-token via ``streaming_chat``.
    Special inputs: ``exit`` ends the session, ``clear`` restarts the
    conversation (keeping the system prompt).

    Args:
        args: parsed CLI namespace with model/tokenizer paths and the
            generation hyper-parameters forwarded to ``streaming_chat``.
    """
    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    # roles[0] tags the system prompt; roles[1]/roles[2] are the speakers.
    roles = ["", "Human", "Assistant"]
    past_key_values = None
    history = [{"role": roles[0], "message": SYSTEM}]

    while True:
        input_query = input(f"\n{roles[1]}: ")
        if input_query.strip() == "exit":
            break
        if input_query.strip() == "clear":
            # Reset the conversation but re-seed the system prompt.
            # (Previously `history` was reset to [] here, silently dropping
            # the system prompt for the rest of the session.)
            past_key_values = None
            history = [{"role": roles[0], "message": SYSTEM}]
            continue

        print(f"\n{roles[2]}: ", end="")
        gen_len = 0  # number of characters of the reply already printed
        for response, history, past_key_values in streaming_chat(
            model,
            tokenizer,
            input_query,
            history=history,
            roles=roles,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            do_sample=args.do_sample,
            length_penalty=args.length_penalty,
            max_new_tokens=args.max_new_tokens,
            past_key_values=past_key_values,
            return_past_key_values=True,
        ):
            # `response` is the full reply so far; print only the new suffix.
            print(response[gen_len:], end="", flush=True)
            gen_len = len(response)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default=None, help="path to chat version model")
    parser.add_argument('--tokenizer_path', type=str, default=None, help="path to chat version tokenizer")
    parser.add_argument('--temperature', type=float, default=0.8, help="set temperature")
    parser.add_argument('--top_p', type=float, default=0.95, help="set top p value")
    parser.add_argument('--top_k', type=int, default=50, help="set top k value")
    parser.add_argument('--do_sample', type=bool, default=True, help="whether turn on do_sample or not")
    parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty")
    parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens")
    args = parser.parse_args()
    main(args)