"vscode:/vscode.git/clone" did not exist on "2c08ff23c07f2f8d51da8e1783c5346dccc1fd12"
benchmark_latency.py 11.7 KB
Newer Older
1
"""Benchmark the latency of processing a single batch of requests."""
2
import argparse
3
import json
4
import time
5
from pathlib import Path
6
from typing import List, Optional
7
8
9

import numpy as np
import torch
10
from tqdm import tqdm
11

Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm import LLM, SamplingParams
13
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
14
from vllm.inputs import PromptType
15
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
16
from vllm.utils import FlexibleArgumentParser
17
18
19


def main(args: argparse.Namespace):
    """Measure the end-to-end latency of generating one batch of requests.

    Builds an ``LLM`` engine from the parsed command-line arguments and a
    batch of ``args.batch_size`` random prompts, each ``args.input_len``
    token ids long. It then runs ``args.num_iters_warmup`` untimed warm-up
    generations, after which it either:

    * (``args.profile``) runs a single generation under ``torch.profiler``,
      saves the trace and returns without timing anything, or
    * runs ``args.num_iters`` timed generations, prints the mean and the
      10/25/50/75/90/99th percentile latencies, and, if
      ``args.output_json`` is set, writes them to that path as JSON.

    Args:
        args: Namespace produced by the argument parser in ``__main__``.

    Raises:
        ValueError: if a timed benchmark is requested with
            ``args.num_iters < 1``. There would be no samples to summarize.
    """
    print(args)

    # Fail fast, before the expensive model load: the timed path summarizes
    # the samples with np.percentile, which raises on an empty array.
    if not args.profile and args.num_iters < 1:
        raise ValueError(
            "--num-iters must be at least 1 to collect latency samples, "
            f"got {args.num_iters}")

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        speculative_model=args.speculative_model,
        num_speculative_tokens=args.num_speculative_tokens,
        speculative_draft_tensor_parallel_size=\
            args.speculative_draft_tensor_parallel_size,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        max_model_len=args.max_model_len,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        quantization_param_path=args.quantization_param_path,
        device=args.device,
        ray_workers_use_nsight=args.ray_workers_use_nsight,
        use_v2_block_manager=args.use_v2_block_manager,
        enable_chunked_prefill=args.enable_chunked_prefill,
        download_dir=args.download_dir,
        block_size=args.block_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
        load_format=args.load_format,
        distributed_executor_backend=args.distributed_executor_backend,
        otlp_traces_endpoint=args.otlp_traces_endpoint,
        enable_prefix_caching=args.enable_prefix_caching,
    )

    # ignore_eos=True presumably stops sequences from ending early at EOS,
    # so each run generates up to max_tokens and timings stay comparable.
    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # Random token ids in [0, 10000) stand in for real prompts.
    # NOTE(review): assumes 10000 is below the vocab size of the model.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def run_to_completion(
            profile_dir: Optional[str] = None) -> Optional[float]:
        """Generate the dummy batch once.

        With ``profile_dir`` set, the generation runs under torch.profiler.
        The trace is written to that directory, the aggregated profiler
        table is printed, and None is returned. Otherwise returns the
        wall-clock latency of the generation in seconds.
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(dummy_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
            return None

        start_time = time.perf_counter()
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        return time.perf_counter() - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = (Path(".") / "vllm_benchmark_result" /
                           f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark: one wall-clock sample per iteration.
    latencies = np.array([
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Benchmark iterations")
    ])
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    avg_latency = np.mean(latencies)
    print(f'Avg latency: {avg_latency} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": float(avg_latency),
            "latencies": latencies.tolist(),
            # json.dump serializes the integer percentage keys as strings.
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

124
125

if __name__ == '__main__':
    # Command-line interface: every option below is read by main(), except
    # --use-beam-search, which main() never reads.
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
    parser.add_argument('--speculative-draft-tensor-parallel-size',
                        '-spec-draft-tp',
                        type=int,
                        default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len',
                        type=int,
                        default=32,
                        help='Number of (random) prompt tokens per request.')
    parser.add_argument('--output-len',
                        type=int,
                        default=128,
                        help='Maximum number of tokens to generate per '
                        'sequence (EOS is ignored).')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='Number of prompts in the benchmarked batch.')
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search',
                        action='store_true',
                        help='Has no effect: this flag is parsed but not '
                        'used by this benchmark.')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard. Defaults to '
              './vllm_benchmark_result/latency_result_<timestamp>.'))
    parser.add_argument("--device",
                        type=str,
                        default="auto",
                        choices=DEVICE_OPTIONS,
                        help='device type for vLLM execution')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument("--enable-prefix-caching",
                        action='store_true',
                        help="Enable automatic prefix caching")
    parser.add_argument('--use-v2-block-manager', action='store_true')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
        help="If specified, use nsight to profile ray workers",
    )
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument(
        '--load-format',
        type=str,
        default=EngineArgs.load_format,
        choices=[
            'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
            'bitsandbytes'
        ],
        help='The format of the model weights to load.\n\n'
        '* "auto" will try to load the weights in the safetensors format '
        'and fall back to the pytorch bin format if safetensors format '
        'is not available.\n'
        '* "pt" will load the weights in the pytorch bin format.\n'
        '* "safetensors" will load the weights in the safetensors format.\n'
        '* "npcache" will load the weights in pytorch format and store '
        'a numpy cache to speed up the loading.\n'
        '* "dummy" will initialize the weights with random values, '
        'which is mainly for profiling.\n'
        '* "tensorizer" will load the weights using tensorizer from '
        'CoreWeave. See the Tensorize vLLM Model script in the Examples '
        'section for more information.\n'
        '* "bitsandbytes" will load the weights using bitsandbytes '
        'quantization.\n')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    parser.add_argument(
        '--otlp-traces-endpoint',
        type=str,
        default=None,
        help='Target URL to which OpenTelemetry traces will be sent.')
    args = parser.parse_args()
    main(args)