# infer_vllm.py — minimal single-turn chat inference example using vLLM
# with a Mistral-format model checkpoint.
import argparse

from vllm import LLM, SamplingParams

def _parse_args() -> argparse.Namespace:
    """Parse command-line options: the user prompt and the model to load."""
    parser = argparse.ArgumentParser(
        description="Run one chat completion with vLLM and print the reply."
    )
    parser.add_argument(
        "--user_prompt",
        type=str,
        default="Explain Machine Learning to me in a nutshell.",
        help="Single user message to send to the model.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.3",
        help="HF repo id or local path of a Mistral-format checkpoint.",
    )
    return parser.parse_args()


def main() -> None:
    """Load the model, run a single-turn chat completion, and print the text."""
    args = _parse_args()

    # Cap the generated completion length; other sampling settings use vLLM defaults.
    sampling_params = SamplingParams(max_tokens=8192)

    # To shard the model across multiple GPUs, add e.g. `tensor_parallel_size=2`.
    # (The kwarg is `tensor_parallel_size`, not `tensor_parallel`.)
    llm = LLM(
        model=args.model_name_or_path,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
    )

    messages = [
        {
            "role": "user",
            "content": args.user_prompt,
        },
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)
    # llm.chat returns one RequestOutput per conversation; take the first
    # completion of the first (and only) request.
    print("output:", outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()