# llm_engine_example.py
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates using the `LLMEngine`
for processing prompts with various sampling parameters.
"""

import argparse

9
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
10
from vllm.utils import FlexibleArgumentParser
11
12


13
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
    """Create a list of test prompts with their sampling parameters.

    Returns:
        A new list of ``(prompt, sampling_params)`` pairs. Each prompt is
        paired with a differently configured ``SamplingParams`` so that one
        run exercises several sampling settings:

        * temperature 0.0 with ``logprobs=1`` and ``prompt_logprobs=1``
        * temperature 0.8 with ``top_k=5`` and a presence penalty
        * ``n=2`` at temperature 0.8 with ``top_p=0.95`` and a frequency
          penalty
    """
    return [
        (
            "A robot may not injure a human being",
            SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
        ),
        (
            "To be or not to be,",
            SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
        ),
        (
            "What is the meaning of life?",
            SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
        ),
    ]

30

31
def process_requests(
    engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]
) -> None:
    """Submit prompts to the engine one per step and print finished outputs.

    Each loop iteration submits at most one pending prompt through
    ``engine.add_request`` and then calls ``engine.step()`` once, so
    submitting new requests is interleaved with stepping the engine. The
    loop ends once every prompt has been submitted and
    ``engine.has_unfinished_requests()`` returns a falsy value.

    Every output returned by ``step()`` whose ``finished`` attribute is
    truthy is printed, followed by a 50-dash separator line. One separator
    is also printed before the loop starts.

    Args:
        engine: The engine that receives the requests and is stepped.
        test_prompts: ``(prompt, sampling_params)`` pairs, submitted in
            order. Request ids are the stringified positions ``"0"``,
            ``"1"``, ... The list is only read; it is left unmodified.
    """
    # Read from a snapshot by index instead of popping from the front of the
    # caller's list: the argument stays intact and each fetch is O(1).
    pending = list(test_prompts)
    next_index = 0

    print("-" * 50)
    while next_index < len(pending) or engine.has_unfinished_requests():
        if next_index < len(pending):
            prompt, sampling_params = pending[next_index]
            # The position in the prompt list doubles as the request id.
            engine.add_request(str(next_index), prompt, sampling_params)
            next_index += 1

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)
                print("-" * 50)
48

49
50
51
52
53
54
55

def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Build an ``LLMEngine`` from parsed command-line arguments.

    The namespace is first converted with ``EngineArgs.from_cli_args`` and
    the result is handed to ``LLMEngine.from_engine_args``.
    """
    return LLMEngine.from_engine_args(EngineArgs.from_cli_args(args))


56
57
def parse_args() -> argparse.Namespace:
    """Parse the command line for this demo.

    Creates a ``FlexibleArgumentParser``, passes it through
    ``EngineArgs.add_cli_args`` so the engine options are registered on it,
    and parses the command line with the resulting parser.

    Returns:
        The result of ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser(
        description="Demo on using the LLMEngine class directly"
    )
    # add_cli_args returns the parser to use from here on, so rebind it.
    parser = EngineArgs.add_cli_args(parser)
    return parser.parse_args()


64
65
66
67
68
def main(args: argparse.Namespace):
    """Run the demo: build the engine, then push the test prompts through it."""
    # Arguments are evaluated left to right, so the engine is still created
    # before the prompt list is built.
    process_requests(initialize_engine(args), create_test_prompts())
69
70


71
# Script entry point: parse the command line and run the demo.
if __name__ == "__main__":
    args = parse_args()
    main(args)