# infer_vllm.py
import time

from openai import OpenAI
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def infer_llama4_vllm(model_path, message, tp_size=1, max_model_len=4096):
    """Run Llama 4 inference locally with vLLM's offline engine."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    messages = [{"role": "user", "content": message}]
    print(f"Prompt: {messages!r}")
    # Moderate temperature with nucleus sampling; stop at the EOS token.
    sampling_params = SamplingParams(temperature=0.3,
                                     top_p=0.9,
                                     max_tokens=4096,
                                     stop_token_ids=[tokenizer.eos_token_id])

    # enforce_eager=True skips CUDA graph capture (saves memory, slower decode);
    # dtype="float16" halves weight memory relative to float32.
    llm = LLM(model=model_path,
              max_model_len=max_model_len,
              trust_remote_code=True,
              enforce_eager=True,
              dtype="float16",
              tensor_parallel_size=tp_size)
    # Tokenize the chat with the model's chat template, then generate.
    prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
    start_time = time.time()
    outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
    print(f"Total inference time: {time.time() - start_time:.2f}s")
    # Print the generated text for each prompt.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
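

# Batched variant: a minimal sketch, not part of the original script. vLLM's
# offline engine accepts a list of token-id prompts and decodes them together,
# which is usually much faster than calling generate() once per prompt. The
# function name and signature are illustrative; it reuses the LLM and tokenizer
# objects built as above.
def infer_llama4_vllm_batch(llm, tokenizer, message_list, sampling_params):
    """Generate completions for several user messages in one vLLM call."""
    prompt_token_ids = [
        tokenizer.apply_chat_template([{"role": "user", "content": m}],
                                      add_generation_prompt=True)
        for m in message_list
    ]
    outputs = llm.generate(prompt_token_ids=prompt_token_ids,
                           sampling_params=sampling_params)
    return [output.outputs[0].text for output in outputs]

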
def infer_llama4_client(client, message, model_name='Llama-4-Scout-17B-16E-Instruct'):
    """Run Llama 4 inference against an OpenAI-compatible API server."""
    print(f"Prompt: {message!r}")
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": message}
            ],
        model=model_name,
        stream=False
    )
    print(f"Response text: {response.choices[0].message.content!r}")
    return response
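

# Streaming variant: a sketch, not part of the original script, using the
# standard OpenAI client streaming API (stream=True yields chunks whose
# .choices[0].delta.content carries the newly generated text, or None).
def infer_llama4_client_stream(client, message, model_name='Llama-4-Scout-17B-16E-Instruct'):
    """Stream a Llama 4 completion incrementally from an OpenAI-compatible server."""
    stream = client.chat.completions.create(
        messages=[{"role": "user", "content": message}],
        model=model_name,
        stream=True,
    )
    text = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content or ""
        print(delta, end="", flush=True)
        text += delta
    print()
    return text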


if __name__ == "__main__":
    # Local inference with vLLM
    infer_llama4_vllm(model_path="meta-llama/Llama-4-Scout-17B-16E-Instruct",
                      message="你好",
                      tp_size=1,
                      max_model_len=4096)
    # Inference via an OpenAI-compatible API server
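    # A server can be launched with vLLM's OpenAI entrypoint, for example
    # (a common invocation; exact flags may vary across vLLM versions):
    #   python -m vllm.entrypoints.openai.api_server \
    #       --model meta-llama/Llama-4-Scout-17B-16E-Instruct \
    #       --tensor-parallel-size 1 --max-model-len 4096 --port 8000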
    # url = "127.0.0.1:8000"  # adjust to match your deployment
    # client = OpenAI(api_key="EMPTY", base_url=f"http://{url}/v1")
    # infer_llama4_client(client=client,
    #                     message="你好",
    #                     model_name='Llama-4-Scout-17B-16E-Instruct')