"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams


def main(args: argparse.Namespace):
    """Measure the end-to-end latency of generating for one batch of prompts.

    Builds an ``LLM`` engine from the command-line options, creates
    ``args.batch_size`` random prompts of ``args.input_len`` token ids each,
    and times how long ``llm.generate`` takes to finish all of them.

    After one untimed warm-up run there are two modes:

    * ``args.profile`` set: a single run is wrapped in the PyTorch profiler
      (CPU and CUDA activities). The trace is handed to a TensorBoard trace
      handler writing under ``args.profile_result_dir``, or a timestamped
      directory under ``./vllm_benchmark_result`` when none is given. No
      latency is reported in this mode.
    * otherwise: the batch is run ``args.num_iters`` times and the mean
      wall-clock latency in seconds is printed.

    Args:
        args: Parsed command-line namespace built by this script's argument
            parser.
    """
    print(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        device=args.device,
        ray_workers_use_nsight=args.ray_workers_use_nsight,
    )

    sampling_params = SamplingParams(
        n=args.n,
        # Temperature 0.0 when beam search is requested, 1.0 otherwise.
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        # Do not stop at EOS: output length is then bounded by max_tokens
        # rather than by whatever the model happens to emit, which keeps
        # the work per iteration comparable across runs.
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # Random prompt token ids in [0, 10000), shape (batch_size, input_len).
    # The same prompts are reused for the warm-up and every timed run.
    # NOTE(review): assumes the model vocabulary has at least 10000 entries;
    # confirm for small-vocab models.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()

    def run_to_completion(profile_dir: Optional[str] = None
                          ) -> Optional[float]:
        """Generate for the whole dummy batch once.

        Args:
            profile_dir: If truthy, the run is wrapped in the PyTorch profiler
                and its trace is handed to a TensorBoard trace handler for
                this directory. A ``pathlib.Path`` is accepted too, because
                it is passed through ``str()``.

        Returns:
            The wall-clock latency of the run in seconds when not profiling;
            ``None`` when profiling, in which case the profiler's key averages
            are printed instead.
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
            return None

        start_time = time.perf_counter()
        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        end_time = time.perf_counter()
        return end_time - start_time

    # One untimed run so one-off startup costs do not skew the
    # measurements below.
    print("Warming up...")
    run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark. This loop only runs when profiling is off, so the progress
    # bar is labelled as benchmark (not profiling) iterations.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Benchmark iterations"):
        latencies.append(run_to_completion(profile_dir=None))

    # With --num-iters <= 0 there is nothing to average; np.mean([]) would
    # emit a RuntimeWarning and report nan as the latency.
    if not latencies:
        print('No benchmark iterations were run (--num-iters must be >= 1).')
        return
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    # Command-line entry point: parse the benchmark options and run main().
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model',
                        type=str,
                        default='facebook/opt-125m',
                        help='model name or path passed to the LLM engine')
    parser.add_argument('--tokenizer',
                        type=str,
                        default=None,
                        help='tokenizer name or path passed to the LLM engine')
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None,
                        help='quantization method passed to the LLM engine')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len',
                        type=int,
                        default=32,
                        help='number of prompt token ids per request')
    parser.add_argument('--output-len',
                        type=int,
                        default=128,
                        help='max number of tokens generated per sequence')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='number of prompts in the benchmarked batch')
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search',
                        action='store_true',
                        help='use beam search instead of sampling')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8_e5m2'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
        help="If specified, use nsight to profile ray workers",
    )
    args = parser.parse_args()
    main(args)