# benchmark_latency.py
"""Benchmark the latency of processing a single batch of requests."""
2
import argparse
3
import json
4
import time
5
from pathlib import Path
6
from typing import List, Optional
7
8
9

import numpy as np
import torch
10
from tqdm import tqdm
11

# vLLM imports.
from vllm import LLM, SamplingParams
13
from vllm.inputs import PromptStrictInputs
14
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
15
16
17


def main(args: argparse.Namespace):
    """Measure the latency of generating one batch of dummy requests.

    Builds an ``LLM`` engine from the parsed CLI arguments and creates
    ``args.batch_size`` prompts of ``args.input_len`` random token ids. It
    then runs ``args.num_iters_warmup`` untimed warmup generations and
    continues in one of two modes.

    * ``args.profile`` set: runs one more generation under
      ``torch.profiler``. The trace handler is
      ``torch.profiler.tensorboard_trace_handler`` pointed at
      ``args.profile_result_dir``. If that is unset, a timestamped directory
      under ``./vllm_benchmark_result`` is used. The function prints
      ``key_averages()`` and returns.
    * otherwise: times ``args.num_iters`` generations with
      ``time.perf_counter`` and prints the mean and the 10/25/50/75/90th
      percentile latencies in seconds. If ``args.output_json`` is given, the
      same numbers are also written there as JSON.

    Args:
        args: Namespace produced by this script's argument parser.

    Raises:
        ValueError: In benchmark (non-profile) mode when ``args.num_iters``
            is less than 1, since there would be no samples to summarize.
    """
    print(args)

    # Fail fast, before the expensive engine construction: with zero timed
    # iterations np.percentile would fail later on an empty sample with an
    # unhelpful error. Profile mode never reads num_iters.
    if not args.profile and args.num_iters < 1:
        raise ValueError(
            f"--num-iters must be at least 1, got {args.num_iters}")

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(model=args.model,
              speculative_model=args.speculative_model,
              num_speculative_tokens=args.num_speculative_tokens,
              tokenizer=args.tokenizer,
              quantization=args.quantization,
              tensor_parallel_size=args.tensor_parallel_size,
              trust_remote_code=args.trust_remote_code,
              dtype=args.dtype,
              enforce_eager=args.enforce_eager,
              kv_cache_dtype=args.kv_cache_dtype,
              quantization_param_path=args.quantization_param_path,
              device=args.device,
              ray_workers_use_nsight=args.ray_workers_use_nsight,
              use_v2_block_manager=args.use_v2_block_manager,
              enable_chunked_prefill=args.enable_chunked_prefill,
              download_dir=args.download_dir,
              block_size=args.block_size)

    sampling_params = SamplingParams(
        n=args.n,
        # Temperature 0 when beam search is on, 1.0 otherwise.
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        # Ignore EOS so generation length is governed by max_tokens rather
        # than by when the model happens to emit EOS, keeping the work per
        # iteration comparable across runs.
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # Random token ids in [0, 10000), one row per prompt. The token values
    # are assumed not to matter for latency; only the
    # (batch_size, input_len) shape does.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_inputs: List[PromptStrictInputs] = [{
        "prompt_token_ids": prompt_ids
    } for prompt_ids in dummy_prompt_token_ids.tolist()]

    def run_to_completion(profile_dir: Optional[str] = None):
        """Run one generation over ``dummy_inputs``.

        If ``profile_dir`` is truthy (a ``str`` or ``Path``), run under
        ``torch.profiler`` with a TensorBoard trace handler for that
        directory, print the averaged profiler table, and return ``None``.
        Otherwise return the wall-clock latency of the call in seconds.
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(dummy_inputs,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
            return None

        start_time = time.perf_counter()
        llm.generate(dummy_inputs,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        return time.perf_counter() - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark: time each generation. This loop only measures wall-clock
    # time; profiling happens only in the --profile branch above.
    latencies = np.array([
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Benchmark iterations")
    ])

    percentages = [10, 25, 50, 75, 90]
    percentiles = np.percentile(latencies, percentages)
    avg_latency = float(np.mean(latencies))
    print(f'Avg latency: {avg_latency} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified. Values are converted to plain Python
    # floats so the file does not depend on numpy scalar handling.
    if args.output_json:
        results = {
            "avg_latency": avg_latency,
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

113
114

def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the latency benchmark.

    Each flag's destination is one of the ``args.*`` attributes that
    ``main`` reads. Flags are added in the order they appear in ``--help``.
    """
    cli = argparse.ArgumentParser(
        description=("Benchmark the latency of processing a single batch of "
                     "requests till completion."))

    # Model, tokenizer, quantization and parallelism.
    cli.add_argument("--model", type=str, default="facebook/opt-125m")
    cli.add_argument("--speculative-model", type=str, default=None)
    cli.add_argument("--num-speculative-tokens", type=int, default=None)
    cli.add_argument("--tokenizer", type=str, default=None)
    cli.add_argument("--quantization",
                     "-q",
                     choices=[*QUANTIZATION_METHODS, None],
                     default=None)
    cli.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)

    # Synthetic workload and sampling.
    cli.add_argument("--input-len", type=int, default=32)
    cli.add_argument("--output-len", type=int, default=128)
    cli.add_argument("--batch-size", type=int, default=8)
    cli.add_argument("--n",
                     type=int,
                     default=1,
                     help="Number of generated sequences per prompt.")
    cli.add_argument("--use-beam-search", action="store_true")

    # Iteration counts.
    cli.add_argument("--num-iters-warmup",
                     type=int,
                     default=10,
                     help="Number of iterations to run for warmup.")
    cli.add_argument("--num-iters",
                     type=int,
                     default=30,
                     help="Number of iterations to run.")

    # Model loading and numerics.
    cli.add_argument("--trust-remote-code",
                     action="store_true",
                     help="trust remote code from huggingface")
    cli.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
        help=("data type for model weights and activations. "
              'The "auto" option will use FP16 precision '
              "for FP32 and FP16 models, and BF16 precision "
              "for BF16 models."))
    cli.add_argument("--enforce-eager",
                     action="store_true",
                     help="enforce eager mode and disable CUDA graph")
    cli.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help=('Data type for kv cache storage. If "auto", will use model '
              "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
              "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)"))
    cli.add_argument(
        "--quantization-param-path",
        type=str,
        default=None,
        help=("Path to the JSON file containing the KV cache scaling factors. "
              "This should generally be supplied, when KV cache dtype is FP8. "
              "Otherwise, KV cache scaling factors default to 1.0, which may cause "
              "accuracy issues. FP8_E5M2 (without scaling) is only supported on "
              "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is "
              "instead supported for common inference criteria."))

    # Profiling.
    cli.add_argument("--profile",
                     action="store_true",
                     help="profile the generation process of a single batch")
    cli.add_argument(
        "--profile-result-dir",
        type=str,
        default=None,
        help=("path to save the pytorch profiler output. Can be visualized "
              "with ui.perfetto.dev or Tensorboard."))

    # Runtime / engine options.
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="device type for vLLM execution, supporting CUDA and CPU.")
    cli.add_argument("--block-size",
                     type=int,
                     default=16,
                     help="block size of key/value cache")
    cli.add_argument(
        "--enable-chunked-prefill",
        action="store_true",
        help=("If True, the prefill requests can be chunked based on the "
              "max_num_batched_tokens"))
    cli.add_argument("--use-v2-block-manager", action="store_true")
    cli.add_argument("--ray-workers-use-nsight",
                     action="store_true",
                     help="If specified, use nsight to profile ray workers")
    cli.add_argument(
        "--download-dir",
        type=str,
        default=None,
        help=("directory to download and load the weights, "
              "default to the default cache dir of huggingface"))

    # Results.
    cli.add_argument("--output-json",
                     type=str,
                     default=None,
                     help="Path to save the latency results in JSON format.")
    return cli


if __name__ == "__main__":
    args = _build_parser().parse_args()
    main(args)