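"""Quick inference test for ChatGLM2 and Llama 3 with transformers or vLLM.

A model path containing "llama" is routed to the Llama 3 functions; any other
path is treated as ChatGLM2. Example usage (paths are placeholders for locally
downloaded checkpoints):

    python vllm_test.py --model_path /path/to/llama3-8b-instruct --query "DCU是什么?"
    python vllm_test.py --model_path /path/to/chatglm2-6b --query "DCU是什么?" --use_hf
"""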
import os
import time
import torch
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from vllm import LLM, SamplingParams


def infer_hf_chatglm(model_path, prompt):
    '''Run ChatGLM2 inference with transformers.'''
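    # trust_remote_code is required because ChatGLM ships its own modeling code with the checkpoint.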
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").half().cuda()
    model = model.eval()
    start_time = time.time()
    generated_text, _ = model.chat(tokenizer, prompt, history=[])
    print("chat time", time.time() - start_time)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_hf_llama3(model_path, prompt):
    '''Run Llama 3 inference with transformers.'''
    input_query = {"role": "user", "content": prompt}
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype="auto", device_map="auto")

    input_ids = tokenizer.apply_chat_template(
        [input_query,], add_generation_prompt=True, return_tensors="pt").to(model.device)

    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_p=0.95,
    )

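    # Slice off the prompt tokens so only the newly generated completion is decoded.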
    response = outputs[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(response, skip_special_tokens=True)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_vllm_llama3(model_path, message, tp_size=1, max_model_len=1024):
    '''Run Llama 3 inference with vLLM.'''
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    messages = [{"role": "user", "content": message}]
    print(f"Prompt: {messages!r}")
    sampling_params = SamplingParams(temperature=1,
                                     top_p=0.95,
                                     max_tokens=1024,
                                     stop_token_ids=[tokenizer.eos_token_id])

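    # enforce_eager=True skips CUDA graph capture (lower memory, slower decode);
    # tensor_parallel_size shards the model across tp_size GPUs.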
    llm = LLM(model=model_path,
              max_model_len=max_model_len,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="float16",
              tensor_parallel_size=tp_size)
    # generate answer
    start_time = time.time()
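    # apply_chat_template tokenizes by default, so the chat prompt is passed to vLLM as token ids.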
    prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
    outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
    print("total infer time", time.time() - start_time)
    # results
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


def infer_vllm_chatglm(model_path, message, tp_size=1):
    '''Run ChatGLM2 inference with vLLM.'''
    sampling_params = SamplingParams(temperature=1.0,
                                     top_p=0.9,
                                     max_tokens=1024)

    llm = LLM(model=model_path,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="float16",
              tensor_parallel_size=tp_size)
    # generate answer
    print(f"chatglm2 Prompt: {message!r}")
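    # The raw prompt string is passed directly; no chat template is applied here (unlike infer_vllm_llama3).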
    outputs = llm.generate(message, sampling_params=sampling_params)
    # results
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', default='', help='Local path to the model checkpoint.')
    parser.add_argument('--query', default="DCU是什么?", help='Question to ask the model.')
    parser.add_argument('--use_hf', action='store_true')
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    is_llama = "llama" in args.model_path
    print("Is llama", is_llama)
    if args.use_hf:
        # transformers
        if is_llama:
            infer_hf_llama3(args.model_path, args.query)
        else:
            infer_hf_chatglm(args.model_path, args.query)
    else:
        # vllm
        if is_llama:
            infer_vllm_llama3(args.model_path, args.query)
        else:
            infer_vllm_chatglm(args.model_path, args.query)