# SPDX-License-Identifier: Apache-2.0
"""
Offline benchmark to test the long document QA throughput.

Example usage:
    # This workload samples 8 different prompts with a default input
    # length of 20000 tokens, then replicates each prompt 2 times
    # in random order.
    python benchmark_long_document_qa_throughput.py \
        --model meta-llama/Llama-2-7b-chat-hf \
        --enable-prefix-caching \
        --num-documents 8 \
        --repeat-count 2 

Command-line arguments:
    --num-documents: The number of documents to sample prompts from.
                     (Optional, default: 8)

    --document-length: The length of each document in tokens. 
                       (Optional, default: 20000)

    --output-len: The number of tokens to generate for each prompt.
                  (Optional, default: 10)

    --repeat-count: The number of times to repeat each prompt.
                    (Optional, default: 2)

    --repeat-mode: The mode to repeat prompts. The supported modes are:
        - 'random': shuffle the prompts randomly. (Default)
        - 'tile': the entire prompt list is repeated in sequence. (Potentially
                  lowest cache hit)
        - 'interleave': each prompt is repeated consecutively before 
                        moving to the next element. (Highest cache hit)
    
    --shuffle-seed: Random seed when the repeat mode is "random".
                    (Optional, default: 0)

In addition, all vLLM engine arguments are accepted and used to initialize
the LLM engine. Refer to `vllm.engine.arg_utils.EngineArgs` for more
details.
"""

import dataclasses
import random
import time

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
    """
    Run all prompts through the LLM and print the elapsed time.

    Despite the ``test_`` prefix this is a benchmark helper, not a unit test.

    Args:
        llm: The language model used for generating responses. Must expose
            ``generate(prompts, sampling_params=...)``.
        sampling_params: Sampling parameter used to generate the response.
        prompts: A list of prompt strings to be processed by the LLM.
    """
    # perf_counter() is monotonic and high-resolution, so the measurement is
    # not skewed by system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()
    llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.perf_counter()
    print(f"Time to execute all requests: {end_time - start_time:.4f} secs")


# The name starts with "test_", so pytest-style runners that see this name
# would try to collect it and call it with llm=None. Opt out explicitly.
test_long_document_qa.__test__ = False


def repeat_prompts(prompts, repeat_count, mode: str):
    """
    Repeat each prompt in the list for a specified number of times.
    The order of prompts in the output list depends on the mode.

    The input is never modified; a new list is always returned.

    Args:
        prompts: An iterable of prompts to be repeated.
        repeat_count: The number of times each prompt is repeated.
        mode: The mode of repetition. Supported modes are:
            - 'random': Shuffle the prompts randomly after repetition
              (uses the module-level ``random`` state).
            - 'tile': Repeat the entire prompt list in sequence.
              Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
            - 'interleave': Repeat each prompt consecutively before moving to
              the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].

    Returns:
        A list of repeated prompts in the specified order.

    Raises:
        ValueError: If an invalid mode is provided.
    """
    print("Repeat mode: ", mode)

    if mode == "random":
        # list() copies the input so tuples and other iterables work too.
        repeated_prompts = list(prompts) * repeat_count
        random.shuffle(repeated_prompts)
        return repeated_prompts

    if mode == "tile":
        return list(prompts) * repeat_count

    if mode == "interleave":
        return [prompt for prompt in prompts for _ in range(repeat_count)]

    raise ValueError(
        f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
    )


def main(args):
    """
    Build the synthetic prompts, start the LLM engine and run the benchmark.

    A warm-up pass is run first, then the measured pass over the repeated
    prompts. Each pass prints its own elapsed time.

    Args:
        args: Parsed command-line namespace. Reads ``shuffle_seed``,
            ``document_length``, ``num_documents``, ``repeat_count``,
            ``repeat_mode`` and ``output_len``; the namespace is also passed
            to ``EngineArgs.from_cli_args`` to configure the engine.
    """
    random.seed(args.shuffle_seed)

    # Every document shares the same filler text; build it once instead of
    # once per document.
    document_body = " ".join(["hi"] * args.document_length)

    # Prepare the prompts:
    # we append the document id at the beginning to avoid any of the document
    # being the prefix of other documents
    prompts = [str(i) + document_body for i in range(args.num_documents)]

    prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)

    # Warm-up prompts start with a different leading text, so they do not
    # share a prefix with the measured prompts above.
    warmup_prompts = [
        "This is warm up request " + str(i) + document_body
        for i in range(args.num_documents)
    ]

    # Create the LLM engine
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))
    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

    print("------warm up------")
    test_long_document_qa(
        llm=llm,
        prompts=warmup_prompts,
        sampling_params=sampling_params,
    )

    print("------start generating------")
    test_long_document_qa(
        llm=llm,
        prompts=prompts,
        sampling_params=sampling_params,
    )


# Script entry point: declare the benchmark-specific options, add all vLLM
# engine options on top, then run the benchmark.
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the performance with or "
        "without automatic prefix caching."
    )

    parser.add_argument(
        "--document-length",
        type=int,
        # Roughly the number of tokens for a system paper,
        # excluding images
        default=20000,
        help="Length of each document in tokens.",
    )

    parser.add_argument(
        "--num-documents",
        type=int,
        default=8,
        help="Number of documents to sample prompts from.",
    )

    parser.add_argument(
        "--output-len",
        type=int,
        default=10,
        help="Number of tokens to generate for each prompt.",
    )

    parser.add_argument(
        "--repeat-count",
        type=int,
        default=2,
        help="Number of times to repeat each prompt",
    )

    parser.add_argument(
        "--repeat-mode",
        type=str,
        default="random",
        # Reject unsupported modes at parse time rather than later inside
        # repeat_prompts().
        choices=["random", "tile", "interleave"],
        help="The mode to repeat prompts. The supported "
        'modes are "random", "tile", and "interleave". '
        "See repeat_prompts() in the source code for details.",
    )

    parser.add_argument(
        "--shuffle-seed",
        type=int,
        default=0,
        help='Random seed when the repeat mode is "random"',
    )

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)