benchmark_latency.py 5.32 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark the latency of processing a single batch of requests."""
3
import argparse
4
import dataclasses
5
import json
6
import time
7
from pathlib import Path
8
from typing import List, Optional
9
10
11

import numpy as np
import torch
12
from tqdm import tqdm
13

Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm import LLM, SamplingParams
15
from vllm.engine.arg_utils import EngineArgs
16
from vllm.inputs import PromptType
17
from vllm.sampling_params import BeamSearchParams
18
from vllm.utils import FlexibleArgumentParser
19
20
21


def main(args: argparse.Namespace):
    """Time repeated generation over one fixed batch of random-token prompts.

    An ``LLM`` is built from the engine options carried by ``args`` and a
    batch of ``args.batch_size`` prompts, each ``args.input_len`` random
    token ids long, is created once and reused for every pass.  After
    ``args.num_iters_warmup`` warmup passes (their timings are discarded)
    the function does one of two things:

    * ``args.profile`` set: run a single pass under ``torch.profiler``,
      print the operator table and return.
    * otherwise: run ``args.num_iters`` timed passes, print the average
      and the 10/25/50/75/90/99th percentile latencies, and, when
      ``args.output_json`` is set, write the same numbers to that path.

    Args:
        args: Parsed command-line namespace.  Besides whatever
            ``EngineArgs.from_cli_args`` consumes, this function reads
            ``n``, ``output_len``, ``batch_size``, ``input_len``,
            ``use_beam_search``, ``num_iters_warmup``, ``num_iters``,
            ``profile``, ``profile_result_dir`` and ``output_json``.
    """
    print(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(EngineArgs.from_cli_args(args)))

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # One row of random token ids (drawn from [0, 10000)) per prompt.
    token_id_matrix = np.random.randint(
        10000, size=(args.batch_size, args.input_len))
    dummy_prompts: List[PromptType] = [
        {"prompt_token_ids": row} for row in token_id_matrix.tolist()
    ]

    def llm_generate():
        """Run one generation pass over ``dummy_prompts``.

        Uses ``llm.beam_search`` when ``args.use_beam_search`` is set and
        ``llm.generate`` otherwise; outputs are discarded.
        """
        if args.use_beam_search:
            beam_params = BeamSearchParams(
                beam_width=args.n,
                max_tokens=args.output_len,
                ignore_eos=True,
            )
            llm.beam_search(dummy_prompts, beam_params)
            return
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)

    def run_to_completion(profile_dir: Optional[str] = None):
        """Execute one pass, either timed or under the torch profiler.

        Args:
            profile_dir: When falsy, the pass is timed with
                ``time.perf_counter``.  Otherwise it is the directory
                handed to ``tensorboard_trace_handler`` and the pass runs
                under ``torch.profiler.profile``.

        Returns:
            Elapsed seconds for a timed pass; ``None`` for a profiled one.
        """
        if not profile_dir:
            began = time.perf_counter()
            llm_generate()
            return time.perf_counter() - began

        trace_handler = torch.profiler.tensorboard_trace_handler(
            str(profile_dir))
        with torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                on_trace_ready=trace_handler) as prof:
            llm_generate()
        print(prof.key_averages().table(sort_by="self_cuda_time_total"))
        return None

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        # Fall back to a timestamped directory below the current working
        # directory when no explicit location was given.
        profile_dir = args.profile_result_dir or (
            Path(".") / "vllm_benchmark_result" /
            f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    timed_passes = tqdm(range(args.num_iters), desc="Profiling iterations")
    latencies = np.array(
        [run_to_completion(profile_dir=None) for _ in timed_passes])

    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    mean_latency = np.mean(latencies)

    print(f"Avg latency: {mean_latency} seconds")
    for pct, value in zip(percentages, percentiles):
        print(f"{pct}% percentile latency: {value} seconds")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "avg_latency": mean_latency,
        "latencies": latencies.tolist(),
        "percentiles": {
            pct: value
            for pct, value in zip(percentages, percentiles.tolist())
        },
    }
    with open(args.output_json, "w") as out_file:
        json.dump(results, out_file, indent=4)

112
113

if __name__ == "__main__":
    # Benchmark-specific options as (flag, add_argument keyword options),
    # listed in the order they are registered on the parser.  Engine
    # options are appended afterwards by EngineArgs.add_cli_args.
    benchmark_options = [
        ("--input-len", dict(type=int, default=32)),
        ("--output-len", dict(type=int, default=128)),
        ("--batch-size", dict(type=int, default=8)),
        ("--n",
         dict(type=int,
              default=1,
              help="Number of generated sequences per prompt.")),
        ("--use-beam-search", dict(action="store_true")),
        ("--num-iters-warmup",
         dict(type=int,
              default=10,
              help="Number of iterations to run for warmup.")),
        ("--num-iters",
         dict(type=int, default=30, help="Number of iterations to run.")),
        ("--profile",
         dict(action="store_true",
              help="profile the generation process of a single batch")),
        ("--profile-result-dir",
         dict(type=str,
              default=None,
              help=("path to save the pytorch profiler output. "
                    "Can be visualized with ui.perfetto.dev or "
                    "Tensorboard."))),
        ("--output-json",
         dict(type=str,
              default=None,
              help="Path to save the latency results in JSON format.")),
    ]

    parser = FlexibleArgumentParser(
        description=("Benchmark the latency of processing a single batch "
                     "of requests till completion."))
    for flag, options in benchmark_options:
        parser.add_argument(flag, **options)

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)