# Scraped blame-view header: "vllm/vscode:/vscode.git/clone" did not exist on
# "3f42b05fbc53e50813a1619f5fc770f17ac2a1b6"; benchmark_latency.py (7.93 KB).
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


def main(args: argparse.Namespace):
    """Measure the latency of generating one batch of dummy prompts.

    Builds an ``LLM`` engine from the command-line arguments, then runs
    ``args.num_iters_warmup`` untimed warmup passes. After that it does one
    of two things:

    * If ``args.profile`` is set, it runs a single pass under
      ``torch.profiler``, writes a TensorBoard trace to
      ``args.profile_result_dir`` (or a timestamped directory under
      ``./vllm_benchmark_result``), prints the aggregated operator table,
      and returns.
    * Otherwise, it times ``args.num_iters`` passes with
      ``time.perf_counter`` and prints the mean latency and the
      10/25/50/75/90th percentile latencies in seconds.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` block).
    """
    print(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(model=args.model,
              tokenizer=args.tokenizer,
              quantization=args.quantization,
              tensor_parallel_size=args.tensor_parallel_size,
              trust_remote_code=args.trust_remote_code,
              dtype=args.dtype,
              enforce_eager=args.enforce_eager,
              kv_cache_dtype=args.kv_cache_dtype,
              quantization_param_path=args.quantization_param_path,
              device=args.device,
              ray_workers_use_nsight=args.ray_workers_use_nsight,
              enable_chunked_prefill=args.enable_chunked_prefill,
              download_dir=args.download_dir,
              block_size=args.block_size)

    sampling_params = SamplingParams(
        n=args.n,
        # Temperature 0.0 is used together with beam search; 1.0 otherwise.
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        # Ignore EOS so each request generates up to max_tokens tokens
        # (presumably to keep the work per iteration constant).
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # A (batch_size, input_len) grid of random token ids in [0, 10000).
    # NOTE(review): assumes the model vocabulary has >= 10000 tokens.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()

    def run_to_completion(profile_dir: Optional[str] = None):
        """Generate the dummy batch once.

        When ``profile_dir`` is truthy, run under ``torch.profiler``, write
        a TensorBoard trace into ``profile_dir``, print the aggregated
        operator table and return ``None``. Otherwise return the wall-clock
        latency of the ``llm.generate`` call in seconds.
        """
        if profile_dir:
            activities = [torch.profiler.ProfilerActivity.CPU]
            # Only ask for CUDA activity when the engine actually runs on
            # CUDA; a CPU-only run has no CUDA work to trace.
            if args.device == "cuda":
                activities.append(torch.profiler.ProfilerActivity.CUDA)
            with torch.profiler.profile(
                    activities=activities,
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
            return None

        start_time = time.perf_counter()
        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        end_time = time.perf_counter()
        return end_time - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            # Keep it a str so it matches run_to_completion's annotation.
            profile_dir = str(
                Path(".") / "vllm_benchmark_result" /
                f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')
96
97


if __name__ == '__main__':
    # Command-line interface: every option below is read by main().
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
        help="If specified, use nsight to profile ray workers",
    )
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    args = parser.parse_args()
    # In benchmark (non-profile) mode main() computes np.percentile over the
    # collected latencies, which fails on an empty array. Reject that up
    # front instead of after the (slow) model load.
    if not args.profile and args.num_iters < 1:
        parser.error('--num-iters must be at least 1')
    main(args)