"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser


def main(args: argparse.Namespace):
    """Run the single-batch latency benchmark configured by ``args``.

    Builds an ``LLM`` from the engine CLI options, creates
    ``args.batch_size`` random prompts of ``args.input_len`` token ids each,
    and runs ``args.num_iters_warmup`` untimed warmup passes. Then either:

    * with ``--profile``: runs a single pass under ``torch.profiler``
      (CPU and CUDA activities), writes a TensorBoard trace to the profile
      directory, prints the profiler's key averages, and returns; or
    * otherwise: times ``args.num_iters`` passes with ``time.perf_counter``,
      prints the mean and the 10/25/50/75/90/99th percentile latencies, and,
      if ``--output-json`` is given, also writes them to that file.

    Args:
        args: Parsed CLI namespace holding both the benchmark options and
            the ``EngineArgs`` options added by ``EngineArgs.add_cli_args``.
    """
    print(args)

    if args.use_beam_search:
        # The flag is still accepted on the command line, but nothing in this
        # benchmark consumes it. Say so rather than silently benchmarking
        # regular sampling under a beam-search label.
        print("WARNING: --use-beam-search is currently ignored by this "
              "benchmark; requests are run with regular sampling.")

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        # Presumably keeps decoding past EOS up to ``max_tokens`` so that
        # every iteration does a comparable amount of work.
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)

    # One row of random token ids in [0, 10000) per request in the batch.
    # NOTE(review): assumes the model's vocabulary has >= 10000 entries.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": token_ids
    } for token_ids in dummy_prompt_token_ids.tolist()]

    def run_to_completion(profile_dir: Optional[str] = None):
        """Generate outputs for every dummy prompt once.

        Args:
            profile_dir: If truthy, run the pass under ``torch.profiler`` and
                write a TensorBoard trace into this directory.

        Returns:
            The wall-clock latency of the pass in seconds, or ``None`` when
            the pass was profiled instead of timed.
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(dummy_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
            return None

        start_time = time.perf_counter()
        llm.generate(dummy_prompts,
                     sampling_params=sampling_params,
                     use_tqdm=False)
        return time.perf_counter() - start_time

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            # Timestamped default so repeated runs do not overwrite traces.
            # Stored as str to match run_to_completion's annotation.
            profile_dir = str(
                Path(".") / "vllm_benchmark_result" /
                f"latency_result_{time.time()}")
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = np.array([
        run_to_completion(profile_dir=None)
        for _ in tqdm(range(args.num_iters), desc="Profiling iterations")
    ])
    avg_latency = float(np.mean(latencies))
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {avg_latency} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": avg_latency,
            "latencies": latencies.tolist(),
            # json turns the integer percentage keys into strings.
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--input-len',
                        type=int,
                        default=32,
                        help='Number of (random) prompt tokens per request.')
    parser.add_argument('--output-len',
                        type=int,
                        default=128,
                        help='Maximum number of tokens to generate per '
                        'request.')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        help='Number of requests submitted together.')
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search',
                        action='store_true',
                        help='Currently accepted but ignored; the benchmark '
                        'always uses regular sampling.')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    # Register every engine option so main() can rebuild EngineArgs from the
    # same namespace via EngineArgs.from_cli_args.
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)