benchmark_latency.py 11.6 KB
Newer Older
1
"""Benchmark the latency of processing a single batch of requests."""
2
import argparse
3
import json
4
import time
5
from pathlib import Path
6
from typing import List, Optional, Union
7
8
9

import numpy as np
import torch
10
from tqdm import tqdm
11

Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm import LLM, SamplingParams
13
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
14
from vllm.inputs import PromptType
15
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
16
from vllm.utils import FlexibleArgumentParser
17
18
19


def main(args: argparse.Namespace):
    """Benchmark end-to-end latency of generating one batch of dummy requests.

    Builds an ``LLM`` from the parsed CLI ``args`` and creates
    ``args.batch_size`` random prompts of ``args.input_len`` token ids each.
    It always runs ``args.num_iters_warmup`` untimed warmup passes, then one
    of two modes:

    * ``args.profile`` set: run the batch once under ``torch.profiler``.
      The trace goes to ``args.profile_result_dir``, or to a timestamped
      directory under ``./vllm_benchmark_result`` when none is given. The
      profiler's ``key_averages()`` are printed and the function returns
      without computing latency statistics.
    * otherwise: time ``args.num_iters`` calls to ``llm.generate`` and print
      the mean and the 10/25/50/75/90/99th percentile latencies. If
      ``args.output_json`` is set, also write them there as JSON.

    Args:
        args: Namespace produced by this script's argument parser.

    Raises:
        ValueError: if not profiling and ``args.num_iters`` is less than 1.
            Percentiles of an empty sample are undefined, so this is checked
            before the (expensive) model load.
    """
    print(args)

    # Fail fast: without this check, an empty latency list only blows up in
    # np.percentile after the model has been loaded and warmed up.
    if not args.profile and args.num_iters < 1:
        raise ValueError(
            f"--num-iters must be at least 1, got {args.num_iters}")

    # The parser accepts --use-beam-search, but nothing below reads it.
    # Tell the user rather than silently benchmarking plain sampling.
    if getattr(args, "use_beam_search", False):
        print("WARNING: --use-beam-search is ignored by this benchmark; "
              "requests are sampled with temperature=1.0.")

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        speculative_model=args.speculative_model,
        num_speculative_tokens=args.num_speculative_tokens,
        speculative_draft_tensor_parallel_size=\
            args.speculative_draft_tensor_parallel_size,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        ray_workers_use_nsight=args.ray_workers_use_nsight,
        enable_chunked_prefill=args.enable_chunked_prefill,
        download_dir=args.download_dir,
        block_size=args.block_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        load_format=args.load_format,
        distributed_executor_backend=args.distributed_executor_backend,
        otlp_traces_endpoint=args.otlp_traces_endpoint,
        enable_prefix_caching=args.enable_prefix_caching,
    )

    # ignore_eos=True so that every request generates exactly
    # `args.output_len` tokens, keeping the work per iteration fixed.
    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # Random token ids drawn from [0, 10000). Content is arbitrary; only the
    # shape (batch_size x input_len) matters for latency.
    # NOTE(review): assumes the model vocabulary has >= 10000 entries —
    # confirm for small-vocab models.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": token_ids
    } for token_ids in dummy_prompt_token_ids.tolist()]

    def run_to_completion(
            profile_dir: Optional[Union[str, Path]] = None
    ) -> Optional[float]:
        """Generate the whole dummy batch once.

        Args:
            profile_dir: If truthy (a ``str`` or ``Path``), run under
                ``torch.profiler`` and write a TensorBoard trace there.

        Returns:
            Wall-clock seconds for ``llm.generate`` when not profiling;
            ``None`` when profiling (the profiled run is not timed).
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(dummy_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
            return None

        start_time = time.perf_counter()
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        return time.perf_counter() - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark: timed, non-profiled iterations.
    latencies = np.array([
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Benchmark iterations")
    ])
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    # Plain Python float: prints the same and serializes to JSON cleanly.
    avg_latency = float(np.mean(latencies))
    print(f'Avg latency: {avg_latency} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified. Percentile keys become JSON strings.
    if args.output_json:
        results = {
            "avg_latency": avg_latency,
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

if __name__ == '__main__':
    # CLI entry point: parse engine/benchmark options and run `main`.
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')

    # --- Model / speculative decoding -----------------------------------
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
    parser.add_argument('--speculative-draft-tensor-parallel-size',
                        '-spec-draft-tp',
                        type=int,
                        default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)

    # --- Workload shape ---------------------------------------------------
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search',
                        action='store_true',
                        help='Currently ignored by this benchmark.')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')

    # --- Engine configuration ---------------------------------------------
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')

    # --- Profiling / output -----------------------------------------------
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument("--device",
                        type=str,
                        default="auto",
                        choices=DEVICE_OPTIONS,
                        help='device type for vLLM execution')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument("--enable-prefix-caching",
                        action='store_true',
                        help="Enable automatic prefix caching")
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
        help="If specified, use nsight to profile ray workers",
    )
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    # Fixed: the two help fragments were concatenated without a space
    # ("from 0 to 1.If unspecified").
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    # Fixed: "Examples" and "section" were concatenated without a space.
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    parser.add_argument(
        '--otlp-traces-endpoint',
        type=str,
        default=None,
        help='Target URL to which OpenTelemetry traces will be sent.')

    args = parser.parse_args()
    main(args)