# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
6
7
8
import inspect
import json
import os
import sys
from argparse import RawTextHelpFormatter
9
from collections.abc import Generator
10
from dataclasses import asdict, dataclass
11
from typing import Any, Optional, TypeAlias
12
13

import torch
14
import tqdm
15
16
17

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
18
from vllm.profiler.layerwise_profile import layerwise_profile
19
20
21
22
23
24
25
26
27
28
29
from vllm.utils import FlexibleArgumentParser

# Default for --batch-size: number of requests submitted as a single batch.
BATCH_SIZE_DEFAULT = 1
# Default for --prompt-len: number of tokens in each random prompt.
PROMPT_LEN_DEFAULT = 256


@dataclass
class ProfileContext:
    """Settings that describe one profiling run.

    The number of engine steps that get profiled is controlled by either
    ``num_steps`` or ``complete_num_requests_per_step``; see the comments
    on those fields.
    """

    # Arguments used to construct the LLM under profile.
    engine_args: EngineArgs
    # Number of tokens in every (randomly generated) prompt.
    prompt_len: int
    # Number of requests submitted together.
    batch_size: int

    # The profiler can run in 2 modes,
    # 1. Run profiler for user specified num_steps
    num_steps: Optional[int] = None
    # 2. Run profiler until all requests complete
    complete_num_requests_per_step: Optional[int] = None

    # When set, chrome traces for each profiled step are written here.
    save_chrome_traces_folder: Optional[str] = None
38
39
40
41
42
43
44
45
46


def get_dtype(dtype: str):
    """Map the string ``"torch.float"`` to ``torch.float``.

    Every other value is handed back unchanged.
    """
    return torch.float if dtype == "torch.float" else dtype


47
# Maps an output length to the number of requests that must be assigned
# exactly that output length.
OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
48
49
50
51
52


def compute_request_output_lengths(
    batch_size: int, step_requests: list[int]
) -> OutputLen_NumReqs_Map:
    """
    Given the number of requests, batch_size, and the number of requests
    that each engine-step should process, step_requests, determine the
    output lengths of the requests such that step_requests is honoured.

    Example:
    if batch size = 128 and step_requests = [128, 128, 96, 64, 32, 1]
    then return,
    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
    32 requests should have output length 2,
    32 requests should have output length 3,
    32 requests should have output length 4,
    31 requests should have output length 5,
    1 request should have output length 6.

    Args:
        batch_size (int): Number of requests submitted for profile. This is
            args.batch_size.
        step_requests (list[int]): step_requests[i] is the number of requests
            that the ith engine step should process.

    Returns:
        OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
            number of requests required to have that output-length as values.

    Raises:
        AssertionError: If step_requests is empty, if the first step does not
            process all batch_size requests, or if the step counts cannot be
            turned into a consistent output-length assignment.
    """
    assert step_requests, "step_requests must contain at least the prefill step."

    ol_nr: OutputLen_NumReqs_Map = {}

    # Number of requests that are assigned an output-length
    num_reqs_assigned: int = 0
    num_steps: int = len(step_requests)

    # sanity check. The first step (prefill-step), must process all requests.
    assert step_requests[0] == batch_size, (
        "The first (prefill) step must process all requests.\n"
        f" batch size {batch_size} - step requests {step_requests}"
    )

    # Begin assignments from the last step: requests that are still running at
    # step i (1-based) and have not been assigned yet get output length i.
    for output_length, num_requests_at_step in zip(
        range(num_steps, 0, -1), reversed(step_requests)
    ):
        if num_reqs_assigned == batch_size:
            break

        assert num_reqs_assigned < batch_size

        # Remove the number of requests that have been determined
        # to participate in this step and beyond.
        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
        assert num_reqs_unassigned_at_step >= 0

        if num_reqs_unassigned_at_step > 0:
            ol_nr[output_length] = num_reqs_unassigned_at_step
            num_reqs_assigned += num_reqs_unassigned_at_step

    # sanity checks.
    assert sum(ol_nr.values()) == batch_size, (
        "Number of requests in output-length assignment does not match "
        f"batch-size.\n batch size {batch_size} - "
        f"step requests {step_requests} - assignments {ol_nr}"
    )

    # Check that the output-length is in [1, num-steps]. Output length must be
    # at least 1 as all requests must participate in the prefill-step.
    assert all(1 <= ol <= num_steps for ol in ol_nr), (
        "Output lengths of requests should be in range "
        f"[1, num-engine-steps].\n batch size {batch_size} - "
        f"step requests {step_requests} - assignments {ol_nr}"
    )

    return ol_nr


124
def determine_requests_per_step(context: ProfileContext) -> list[int]:
    """
    Determine number of requests each engine step should process.
    If context.num_steps is set, then all engine steps process the
    same number of requests and the output list is of length
    context.num_steps.

    If context.complete_num_requests_per_step is set, then each decode step
    processes fewer and fewer requests until there are no requests to process.
    In this case, the output list is as big as the number of steps
    required to process all requests.

    Args:
        context: ProfileContext object.

    Returns:
        list[int]: Number of requests to process for all engine-steps.
         output[i], contains the number of requests that the ith step
         should process.

    Raises:
        AssertionError: If context.num_steps is unset (or 0) and
            context.complete_num_requests_per_step is not a positive number.
    """
    if context.num_steps:
        # All requests must run until num_engine_steps. This implies
        # that their output lengths must be equal to num_engine_steps.
        return [context.batch_size] * context.num_steps

    assert (
        context.complete_num_requests_per_step
        and context.complete_num_requests_per_step > 0
    ), (
        "Expected a positive complete_num_requests_per_step argument. "
        f"Instead got {context.complete_num_requests_per_step}"
    )

    # We start dropping after the first decode step.
    step_requests = [
        context.batch_size,  # prefill
        context.batch_size,  # decode
    ]

    # Every following step runs complete_num_requests_per_step fewer requests
    # than the one before it, for as long as at least one request remains.
    drop_per_step = context.complete_num_requests_per_step
    step_requests.extend(
        range(context.batch_size - drop_per_step, 0, -drop_per_step)
    )

    if step_requests[-1] != 1:
        # have 1 request running at the last step. This is often
        # useful
        step_requests.append(1)

    return step_requests


177
178
179
def run_profile(
    context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]
):
    """Profile a prefill step and the following decode steps of an LLM engine.

    The function prints the context, works out how many requests every engine
    step should process, builds the LLM, validates the chosen batch size and
    prompt length against the engine limits, performs a two-step warm up, and
    then profiles each engine step individually. Model and summary tables for
    the prefill step and the first decode step are printed, and results are
    optionally exported.

    Args:
        context: Profiling settings (engine arguments, batch size, prompt
            length and step-selection mode).
        csv_output: Root filename for the csv tables; a trailing ".csv" is
            stripped before the per-table suffixes are appended. No csv files
            are written when falsy.
        json_output: Filename for the json export; ".json" is appended when
            missing. No json file is written when falsy.

    Raises:
        SystemExit: Via ``sys.exit(-1)`` when batch_size * prompt_len exceeds
            max_num_batched_tokens, batch_size exceeds max_num_seqs, or
            prompt_len + the largest output length exceeds max_model_len.
    """
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f"  {key} = {value}")

    requests_per_step: list[int] = determine_requests_per_step(context)

    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
        context.batch_size, requests_per_step
    )

    num_steps_to_profile: int = len(requests_per_step)
    max_output_len: int = max(ol_nr.keys())
    assert max_output_len >= 1

    # Create sampling params
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        # max_tokens is set on a per-request basis.
        max_tokens=None,
        ignore_eos=True,
    )

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len

    scheduler_config = llm.llm_engine.vllm_config.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(
            f"ERROR: chosen batch_size * prompt_len "
            f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is  "
            f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
            f"and therefore cannot be run in a single profile step, please "
            f"choose a smaller batch size or prompt length, or increase "
            f"--max-num-batched-tokens"
        )
        sys.exit(-1)
    if batch_size > max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size"
        )
        sys.exit(-1)
    print(
        "llm.llm_engine.model_config.max_model_len: ",
        max_model_len,
    )
    if prompt_len + max_output_len > max_model_len:
        print(
            f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
            f"{max_output_len} = {prompt_len + max_output_len}) is larger "
            f"than the model's max_model_len ({max_model_len}), please "
            f"choose a smaller prompt_len or max_output_len, or increase "
            f"--max-model-len"
        )
        sys.exit(-1)

    def add_requests():
        """Submit batch_size random-token requests named seq0, seq1, ..."""

        def get_output_len_generator() -> Generator[int, Any, Any]:
            # Yield every output length as many times as there are requests
            # assigned to it.
            for output_len, num_reqs in ol_nr.items():
                for _ in range(num_reqs):
                    yield output_len

        output_len_generator = get_output_len_generator()
        # The vocabulary size is the same for every request; look it up once.
        vocab_size = llm.get_tokenizer().vocab_size
        for i in range(batch_size):
            # NOTE(review): the same SamplingParams object is mutated and
            # re-submitted for every request. This presumably relies on the
            # engine copying the params in add_request - confirm.
            sampling_params.max_tokens = next(output_len_generator)
            assert isinstance(sampling_params.max_tokens, int)

            prompt_token_ids = torch.randint(
                vocab_size, size=(prompt_len,)
            ).tolist()

            llm.llm_engine.add_request(
                request_id=f"seq{i}",
                prompt={"prompt_token_ids": prompt_token_ids},
                params=sampling_params,
            )

    def abort_requests():
        """Abort every request submitted by add_requests."""
        for i in range(batch_size):
            llm.llm_engine.abort_request(f"seq{i}")

    # Warm up run
    print("Warm up run ...")
    add_requests()
    llm.llm_engine.step()  # Prefill
    llm.llm_engine.step()  # Decode
    abort_requests()

    print("Profile run ...")
    add_requests()

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    # Every remaining step is profiled on its own.
    decode_profs = []
    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
        num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
        with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)

    decode_results_list = [prof.results for prof in decode_profs]
    prefill_results = prefill_prof.results
    has_decode = len(decode_results_list) > 0

    LINE_WIDTH = 80

    def print_banner(title: str):
        """Print a table title framed by two rules, followed by a blank line."""
        print("=" * LINE_WIDTH)
        print(f"= {title} (prompt_len={prompt_len}, batch_size={batch_size})")
        print("=" * LINE_WIDTH)
        print()

    print_banner("Prefill Model Table")
    prefill_results.print_model_table()

    if has_decode:
        print()
        print_banner("First Decode Step Model Table")
        decode_results_list[0].print_model_table()

    print()
    print_banner("Prefill Summary Table")
    prefill_results.print_summary_table()

    if has_decode:
        print()
        print_banner("First Decode Step Summary Table")
        decode_results_list[0].print_summary_table()

    if csv_output:
        csv_filename_base = csv_output.removesuffix(".csv")
        prefill_results.export_model_stats_table_csv(
            csv_filename_base + "_prefill_model_table.csv"
        )
        prefill_results.export_summary_stats_table_csv(
            csv_filename_base + "_prefill_summary_table.csv"
        )

        if has_decode:
            decode_results_list[0].export_model_stats_table_csv(
                csv_filename_base + "_decode_model_table.csv"
            )
            decode_results_list[0].export_summary_stats_table_csv(
                csv_filename_base + "_decode_summary_table.csv"
            )

    if json_output:
        cuda_devices = [
            torch.cuda.get_device_properties(dev_idx)
            for dev_idx in range(torch.cuda.device_count())
        ]

        json_dict = {
            "context": {
                "python_version": f"{sys.version}",
                "torch_version": f"{torch.__version__}",
                "torch_cuda_version": f"{torch.version.cuda}",
                "cuda_devices": f"{cuda_devices}",
                **asdict(context),
            },
            "prefill": prefill_results.convert_stats_to_dict(),
        }

        # Decode steps are numbered from 1; nothing is added when only the
        # prefill step was profiled.
        for idx, dr in enumerate(decode_results_list):
            json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

        # Add .json to json_output filename if it doesn't exist already.
        json_output_file = (
            json_output if json_output.endswith(".json") else json_output + ".json"
        )
        with open(json_output_file, "w") as f:
            json.dump(json_dict, f, indent=2)

    if context.save_chrome_traces_folder is not None:
        traces_folder = context.save_chrome_traces_folder
        os.makedirs(traces_folder, exist_ok=True)
        prefill_prof.profiler.export_chrome_trace(
            os.path.join(traces_folder, "prefill.json")
        )
        for idx, decode_prof in enumerate(decode_profs):
            decode_prof.profiler.export_chrome_trace(
                os.path.join(traces_folder, f"decode_{idx + 1}.json")
            )
        print(
            "Traces saved as prefill.json and decode_1.json, etc."
            f" in folder {context.save_chrome_traces_folder}"
        )
390
391


392
def parse_args():
    """Build the command-line parser for the profiler and parse sys.argv.

    Besides the profiler's own options (--csv, --json,
    --save-chrome-traces-folder, --prompt-len, --batch-size) the parser has
    two sub-commands, ``run_num_steps`` and ``run_to_completion``, which
    select how many engine steps are profiled, plus every option registered
    by ``EngineArgs.add_cli_args``.

    Returns:
        The parsed arguments namespace.
    """
    parser = FlexibleArgumentParser(
        description="""
Profile a model

    example:
    ```
    python examples/offline_inference/profiling.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager run_num_steps -n 2
    ```

    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv file. This should be the root "
        "filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv",
    )
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename",
    )
    parser.add_argument(
        "--save-chrome-traces-folder",
        type=str,
        help="Save chrome traces for the prefill and decode "
        "will save traces as prefill.json and decode_1.json, "
        "etc. inside this folder",
    )
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=BATCH_SIZE_DEFAULT,
        help=f"Number of requests to run as a single batch, "
        f"default={BATCH_SIZE_DEFAULT}",
    )

    # One of the two sub-commands must be chosen: without it neither
    # num_steps nor complete_num_requests_per_step is available and the
    # profiler cannot decide how many steps to run.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    run_num_steps_parser = subparsers.add_parser(
        "run_num_steps", help="This variation profiles n engine.step() invocations."
    )
    run_num_steps_parser.add_argument(
        "-n",
        "--num-steps",
        type=int,
        required=True,
        help="Number of engine steps to profile.\n"
        "Setting it to 1, profiles only the prefill step.\n"
        "Setting it to 2, profiles the prefill and first decode step\n"
        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
        "and so on ...",
    )

    run_to_completion_parser = subparsers.add_parser(
        "run_to_completion",
        help="This variation profiles all the engine.step() invocations "
        "until the engine exhausts all submitted requests.",
    )
    run_to_completion_parser.add_argument(
        "-n",
        "--complete-num-requests-per-step",
        type=int,
        required=True,
        help="Complete complete_num_requests_per_step requests every decode step. "
        "For e.g., with batch_size 128 and complete_num_requests_per_step 32, "
        "the profiler is run for 6 engine steps, with the steps processing, "
        "128, 128, 96, 64, 32, 1 requests respectively.\n"
        "Note that we tack-on a one-request step at the end as it is often "
        "useful.",
    )

    EngineArgs.add_cli_args(parser)

    return parser.parse_args()


def main(args):
    """Build a ProfileContext from parsed CLI arguments and run the profiler.

    Args:
        args: Parsed arguments namespace. Attributes whose names match
            ProfileContext constructor parameters are forwarded to it; the
            engine arguments are built with ``EngineArgs.from_cli_args``, and
            ``args.csv`` / ``args.json`` select the export targets.
    """
    # Inspect the constructor once rather than once per argument.
    context_params = inspect.signature(ProfileContext).parameters
    context_kwargs = {k: v for k, v in vars(args).items() if k in context_params}

    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **context_kwargs,
    )
    run_profile(context, csv_output=args.csv, json_output=args.json)
506
507
508
509
510


if __name__ == "__main__":
    # Script entry point: parse the command line and hand it to main().
    main(parse_args())