"docs/vscode:/vscode.git/clone" did not exist on "61a6905ab036fd00eafdb1b0ca130d5feccfe686"
chat.py 3.32 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from vllm import LLM, EngineArgs
5
from vllm.outputs import RequestOutput
6
from vllm.utils.argparse_utils import FlexibleArgumentParser
7
8


9
10
11
def create_parser():
    """Build the command-line parser for this chat example.

    The parser accepts every vLLM engine argument (with ``--model``
    defaulting to ``meta-llama/Llama-3.2-1B-Instruct``), four optional
    sampling overrides, and an optional ``--chat-template-path``.
    """
    cli = FlexibleArgumentParser()

    # Engine arguments, defaulting to a small instruct model.
    EngineArgs.add_cli_args(cli)
    cli.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")

    # Sampling overrides. Each has no explicit default, so an omitted flag
    # yields None in the parsed namespace.
    sampling = cli.add_argument_group("Sampling parameters")
    for flag, value_type in (
        ("--max-tokens", int),
        ("--temperature", float),
        ("--top-p", float),
        ("--top-k", int),
    ):
        sampling.add_argument(flag, type=value_type)

    # Arguments specific to this example script.
    cli.add_argument("--chat-template-path", type=str)

    return cli


26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def _print_outputs(outputs: list[RequestOutput], prompts: list) -> None:
    """Print each prompt alongside the text of its first completion.

    Args:
        outputs: Results returned by ``LLM.chat``, one per prompt.
        prompts: The conversations that produced ``outputs``, in order.

    Raises:
        ValueError: If ``outputs`` and ``prompts`` differ in length.
    """
    # An explicit check (rather than ``assert``) still runs under ``python -O``.
    if len(outputs) != len(prompts):
        raise ValueError(f"Got {len(outputs)} outputs for {len(prompts)} prompts")

    print("\nGenerated Outputs:\n" + "-" * 80)
    for prompt, output in zip(prompts, outputs):
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\n")
        print(f"Generated text: {generated_text!r}")
        print("-" * 80)


def main(args: dict):
    """Demonstrate ``LLM.chat``: a single chat, a batch, and a custom template.

    Args:
        args: Parsed CLI arguments, e.g. ``vars(create_parser().parse_args())``.
            The keys ``max_tokens``, ``temperature``, ``top_p``, ``top_k`` and
            ``chat_template_path`` are consumed here. Missing ones are treated
            as None. All remaining keys are forwarded to ``LLM(...)``.
    """
    # Work on a copy so the caller's dict is not mutated by the pops below.
    args = dict(args)

    # Pop arguments not used by LLM. A default of None lets programmatic
    # callers pass only engine arguments.
    max_tokens = args.pop("max_tokens", None)
    temperature = args.pop("temperature", None)
    top_p = args.pop("top_p", None)
    top_k = args.pop("top_k", None)
    chat_template_path = args.pop("chat_template_path", None)

    # Create an LLM
    llm = LLM(**args)

    # Start from the model's default sampling params and apply only the
    # overrides that were actually supplied.
    sampling_params = llm.get_default_sampling_params()
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens
    if temperature is not None:
        sampling_params.temperature = temperature
    if top_p is not None:
        sampling_params.top_p = top_p
    if top_k is not None:
        sampling_params.top_k = top_k

    print("=" * 80)

    # In this script, we demonstrate how to pass input to the chat method:
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hello! How can I assist you today?"},
        {
            "role": "user",
            "content": "Write an essay about the importance of higher education.",
        },
    ]
    outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
    _print_outputs(outputs, [conversation])

    # You can run batch inference with llm.chat API
    conversations = [conversation for _ in range(10)]

    # We turn on tqdm progress bar to verify it's indeed running batch inference
    outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
    _print_outputs(outputs, conversations)

    # A chat template can be optionally supplied.
    # If not, the model will use its default chat template.
    if chat_template_path is not None:
        # Read as UTF-8 explicitly; the platform default encoding (e.g.
        # cp1252 on Windows) can mis-decode non-ASCII template text.
        with open(chat_template_path, encoding="utf-8") as f:
            chat_template = f.read()

        outputs = llm.chat(
            conversations,
            sampling_params,
            use_tqdm=False,
            chat_template=chat_template,
        )
        _print_outputs(outputs, conversations)
98
99


if __name__ == "__main__":
    # Parse the engine, sampling and example flags, then pass the namespace
    # to main() as a plain dict.
    cli_args: dict = vars(create_parser().parse_args())
    main(cli_args)