cli_demo.py
# -*- coding: utf-8 -*-
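"""Interactive CLI demo for XuanYuan-Chat.

Uses vLLM for generation when it is installed, otherwise falls back to
transformers. Example invocation (supply your own checkpoint path):

    python cli_demo.py -c /path/to/checkpoint -n 1000 -t 0.95 -p 0.95
"""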

import argparse
from conversation import get_conv_template

try:
    from vllm import LLM, SamplingParams
    is_vllm_available = True
    print("Using vllm.generate for inference...")
except ImportError:
    from transformers import LlamaForCausalLM, LlamaTokenizer
    is_vllm_available = False
    print("Using transformers.generate for inference...")


def infer_vllm(llm, sampling_params, prompt):
    assert llm is not None
    assert sampling_params is not None
    generation = llm.generate(prompt, sampling_params, use_tqdm=False)
    outputs = generation[0].outputs[0].text.strip()
    return outputs


def infer(model, tokenizer, prompt):
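    # Generation settings (max_new_tokens, temperature, top_p) are read from the
    # module-level `args` parsed in the __main__ block below.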
    assert model is not None
    assert tokenizer is not None
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        temperature=args.temperature,
        top_p=args.top_p
    )
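    # Drop the prompt tokens so only the newly generated continuation is decoded.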
    outputs = tokenizer.decode(outputs.cpu()[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip()
    return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chat with XuanYuan-Chat via vLLM or transformers")
    parser.add_argument("-c", "--checkpoint_path", type=str, required=True, help="Checkpoint path")
    parser.add_argument("-n", "--max_new_tokens", type=int, default=1000)
    parser.add_argument("-t", "--temperature", type=float, default=0.95)
    parser.add_argument("-p", "--top_p", type=float, default=0.95)
    args = parser.parse_args()

    llm = None
    sampling_params = None
    model = None
    tokenizer = None

    if is_vllm_available:
        print("loading weight with vLLM...")
        sampling_params = SamplingParams(
            temperature=args.temperature,
            top_p=args.top_p,
            stop=["</s>"],
            max_tokens=args.max_new_tokens
        )
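        # Note: tensor_parallel_size=8 assumes 8 GPUs are available; adjust to match your hardware.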
        llm = LLM(args.checkpoint_path, tensor_parallel_size=8)
    else:
        print("loading weight with transformers ...")
        tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint_path, use_fast=False, legacy=True)
        model = LlamaForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto")

    conv = get_conv_template("XuanYuan-Chat")
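    # `conv` accumulates chat turns and renders the full prompt for each generation.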
    print("########")
    print("输入为: EXIT!! 表示退出")
    print("输入为: CLEAR!! 表示清空上下文")
    print("########")
    while True:
        content = input("Input: ")
        if content.strip() == "EXIT!!":
            print("exit....")
            break
        if content.strip() == "CLEAR!!":
            conv = get_conv_template("XuanYuan-Chat")
            print("clear...")
            continue

        conv.append_message(conv.roles[0], content.strip())
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        if is_vllm_available:
            outputs = infer_vllm(llm, sampling_params, prompt)
        else:
            outputs = infer(model, tokenizer, prompt)
        print(f"输出: {outputs}")
        conv.update_last_message(outputs)