# vllm_offline.py — offline chat inference with vLLM (author: mashun1).
from vllm import LLM, SamplingParams


def inference(model_path, *, messages=None, temperature=0.1, top_p=0.95,
              max_tokens=512):
    """Run an offline chat generation with vLLM and print each result.

    Args:
        model_path: Model name or local path, passed as ``model`` to
            ``vllm.LLM``.
        messages: Conversation as a list of ``{"role": ..., "content": ...}``
            dicts handed to ``LLM.chat``. Defaults to a single user turn
            asking "孩子咳嗽老不好怎么办?" ("What should I do when my
            child's cough won't go away?").
        temperature: Sampling temperature (default 0.1).
        top_p: Nucleus-sampling probability mass (default 0.95).
        max_tokens: Maximum number of tokens to generate (default 512).

    Returns:
        The outputs returned by ``llm.chat``, so callers can use the
        generations directly in addition to the printed summary.
    """
    # Build the default conversation per call so no mutable default is
    # shared between invocations.
    if messages is None:
        messages = [
            {"role": "user", "content": "孩子咳嗽老不好怎么办?"}
        ]

    sampling_params = SamplingParams(temperature=temperature,
                                     top_p=top_p,
                                     max_tokens=max_tokens)

    llm = LLM(model=model_path)

    outputs = llm.chat(messages, sampling_params)

    for output in outputs:
        prompt = output.prompt
        # Only the first candidate completion of each request is printed.
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    return outputs


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: parse the model location and run one chat turn.
    parser = ArgumentParser(description="Offline chat inference with vLLM.")

    # Required: if omitted, argparse would leave args.model_path as None and
    # that None would be passed on to LLM(model=...) instead of failing fast
    # here with a usage message.
    parser.add_argument("--model_path", type=str, required=True,
                        help="Model name or local path to load with vLLM.")

    args = parser.parse_args()

    inference(args.model_path)