# offline_profile.py
import inspect
import json
import os
import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
7
from typing import Any, Dict, Generator, List, Optional, TypeAlias
8
9

import torch
10
import tqdm
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.profiler import layerwise_profile
from vllm.utils import FlexibleArgumentParser

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256


@dataclass
class ProfileContext:
    """Configuration for one profiling run.

    The CLI sub-commands ``run_num_steps`` and ``run_to_completion`` each
    populate one of ``num_steps`` / ``complete_num_requests_per_step``;
    ``determine_requests_per_step`` uses ``num_steps`` whenever it is set.
    """

    # Arguments used to construct the ``LLM`` that is profiled.
    engine_args: EngineArgs
    # Number of random prompt tokens in every request.
    prompt_len: int
    # Number of requests submitted together as one batch.
    batch_size: int

    # The profiler can run in 2 modes,
    # 1. Run profiler for user specified num_steps
    num_steps: Optional[int] = None
    # 2. Run profiler until all requests complete
    complete_num_requests_per_step: Optional[int] = None

    # When set, chrome traces of every profiled step are written here.
    save_chrome_traces_folder: Optional[str] = None


def get_dtype(dtype: str):
    """Translate the literal string ``"torch.float"`` into ``torch.float``.

    Every other value is handed back untouched.
    """
    return torch.float if dtype == "torch.float" else dtype


43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
      -> OutputLen_NumReqs_Map:
    """
    Given the number of requests, batch_size, and the number of requests
    that each engine-step should process, step_requests, determine the
    output lengths of the requests such that step_request is honoured.

    Example: 
    if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1]
    then return,
    {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
    32 requests should have output length 2,
    32 requests should have output length 3,
    32 requests should have output length 4,
    31 requests should have output length 5,
    1 request should have output length 6.

    Args:
        batch_size (int): Number of requests submitted for profile. This is
            args.batch_size.
        step_requests (List[int]): step_requests[i] is the number of requests
            that the ith engine step should process.

    Returns:
        OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
            number of requests required to have that output-length as values.
    """
    ol_nr: OutputLen_NumReqs_Map = {}

    # Number of request that are assigned an output-length
    num_reqs_assigned: int = 0
    num_steps: int = len(step_requests)

    # sanity check. The first step (prefill-step), must process all requests.
    assert step_requests[0] == batch_size

    # Begin assignments from the last step.
    output_length: int = num_steps
    for num_requests_at_step in reversed(step_requests):
        if num_reqs_assigned == batch_size:
            break

        assert num_reqs_assigned < batch_size

        # Remove the number of requests that have been determined
        # to participate in this step and beyond.
        num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
        assert num_reqs_unassigned_at_step >= 0

        if num_reqs_unassigned_at_step > 0:
            ol_nr[output_length] = num_reqs_unassigned_at_step
            num_reqs_assigned += num_reqs_unassigned_at_step

        output_length -= 1

    # sanity checks.
    assert sum(ol_nr.values()) == batch_size, \
            ("Number of requests in output-length assignment does not match "
             f"batch-size.\n batch size {batch_size} - "
             f"step requests {step_requests} - assignments {ol_nr}")

    # Check that the output-length is in [1, num-steps]. Output length must be
    # at least 1 as all requests must participate in the prefill-step.
    assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
            ("Output lengths of requests should be in range "
             f"[1, num-engine-steps].\n batch size {batch_size} - "
             f"step requests {step_requests} - assignments {ol_nr}")

    return ol_nr


def determine_requests_per_step(context: "ProfileContext") -> List[int]:
    """
    Determine number of requests each engine step should process.

    If context.num_steps is set, then all engine steps process the
    same number of requests and the output list is of length
    context.num_steps.

    Otherwise context.complete_num_requests_per_step is used. After the
    prefill step and the first decode step (which both process the whole
    batch), each step processes that many fewer requests until none are left.
    A final single-request step is appended if the list does not already end
    in 1.

    Args:
        context: ProfileContext object.

    Returns:
        List[int]: Number of requests to process for all engine-steps.
         output[i] contains the number of requests that the i-th step
         should process.

    Raises:
        ValueError: if context.num_steps is set but is not positive, or if it
            is unset and context.complete_num_requests_per_step is not a
            positive integer.
    """
    if context.num_steps is not None:
        if context.num_steps < 1:
            raise ValueError("Expected a positive num_steps argument. "
                             f"Instead got {context.num_steps}")
        # All requests must run until num_engine_steps. This implies
        # that their output lengths must be equal to num_engine_steps.
        return [context.batch_size] * context.num_steps

    per_step = context.complete_num_requests_per_step
    # An explicit check rather than `assert`: under `python -O` an assert is
    # stripped, and a zero/negative step would make the countdown below
    # never terminate.
    if per_step is None or per_step <= 0:
        raise ValueError(
            "Expected a positive complete_num_requests_per_step argument. "
            f"Instead got {per_step}")

    # We start dropping after the first decode step.
    step_requests = [
        context.batch_size,  # prefill
        context.batch_size,  # decode
    ]
    step_requests.extend(range(context.batch_size - per_step, 0, -per_step))

    if step_requests[-1] != 1:
        # Have 1 request running at the last step; profiling a single
        # request is often useful.
        step_requests.append(1)

    return step_requests


def _check_engine_limits(llm: LLM, batch_size: int, prompt_len: int,
                         max_output_len: int) -> None:
    """Exit with status -1 if the requested batch cannot be profiled.

    The whole batch must fit in one profile step (bounded by the scheduler's
    ``max_num_batched_tokens`` and ``max_num_seqs``), and the longest request
    (prompt plus output) must fit in the model's ``max_model_len``.
    """
    scheduler_config = llm.llm_engine.scheduler_config
    max_model_len = llm.llm_engine.model_config.max_model_len
    max_num_batched_tokens = scheduler_config.max_num_batched_tokens
    max_num_seqs = scheduler_config.max_num_seqs

    if batch_size * prompt_len > max_num_batched_tokens:
        print(f"ERROR: chosen batch_size * prompt_len "
              f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
              f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
              f"and therefore cannot be run in a single profile step, please "
              f"choose a smaller batch size or prompt length, or increase "
              f"--max-num-batched-tokens",
              file=sys.stderr)
        sys.exit(-1)
    if batch_size > max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than "
            f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
            f"single profile step, please choose a smaller batch size",
            file=sys.stderr)
        sys.exit(-1)
    print("llm.llm_engine.model_config.max_model_len: ", max_model_len)
    if prompt_len + max_output_len > max_model_len:
        print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
              f"{max_output_len} = {prompt_len + max_output_len}) is larger "
              f"than the model's max_model_len ({max_model_len}), please "
              f"choose a smaller prompt_len or max_output_len, or increase "
              f"--max-model-len",
              file=sys.stderr)
        sys.exit(-1)


def _print_result_tables(prefill_results, first_decode_results,
                         prompt_len: int, batch_size: int) -> None:
    """Print model and summary tables for prefill and, when available, the
    first decode step, each under an 80-column banner."""
    banner = "=" * 80
    shape = f"(prompt_len={prompt_len}, batch_size={batch_size})"

    sections = [("Prefill Model Table", prefill_results.print_model_table)]
    if first_decode_results is not None:
        sections.append(("First Decode Step Model Table",
                         first_decode_results.print_model_table))
    sections.append(
        ("Prefill Summary Table", prefill_results.print_summary_table))
    if first_decode_results is not None:
        sections.append(("First Decode Step Summary Table",
                         first_decode_results.print_summary_table))

    for idx, (title, print_table) in enumerate(sections):
        if idx > 0:
            print()
        print(banner)
        print(f"= {title} {shape}")
        print(banner)
        print()
        print_table()


def _export_csv(csv_output: str, prefill_results,
                first_decode_results) -> None:
    """Write prefill (and first-decode-step) tables as CSV files.

    A trailing ``.csv`` on ``csv_output`` is dropped, then
    ``_prefill_model_table.csv`` etc. are appended.
    """
    csv_filename_base = csv_output.removesuffix(".csv")
    prefill_results.export_model_stats_table_csv(
        csv_filename_base + "_prefill_model_table.csv")
    prefill_results.export_summary_stats_table_csv(
        csv_filename_base + "_prefill_summary_table.csv")

    if first_decode_results is not None:
        first_decode_results.export_model_stats_table_csv(
            csv_filename_base + "_decode_model_table.csv")
        first_decode_results.export_summary_stats_table_csv(
            csv_filename_base + "_decode_summary_table.csv")


def _export_json(json_output: str, context: ProfileContext, prefill_results,
                 decode_results_list) -> None:
    """Dump environment info, the profile context and the stats of every
    profiled step (``prefill``, ``decode_1``, ``decode_2``, ...) as JSON.

    ``.json`` is appended to ``json_output`` when it is missing.
    """
    cuda_devices = [
        torch.cuda.get_device_properties(dev_idx)
        for dev_idx in range(torch.cuda.device_count())
    ]

    json_dict = {
        "context": {
            "python_version": f"{sys.version}",
            "torch_version": f"{torch.__version__}",
            "torch_cuda_version": f"{torch.version.cuda}",
            "cuda_devices": f"{cuda_devices}",
            **asdict(context)
        },
        "prefill": prefill_results.convert_stats_to_dict(),
    }
    for idx, decode_results in enumerate(decode_results_list, start=1):
        json_dict[f"decode_{idx}"] = decode_results.convert_stats_to_dict()

    json_output_file = json_output if json_output.endswith(
        ".json") else json_output + ".json"
    with open(json_output_file, "w") as f:
        json.dump(json_dict, f, indent=2)


def _export_chrome_traces(folder: str, prefill_prof, decode_profs) -> None:
    """Save ``prefill.json`` and ``decode_<n>.json`` chrome traces in
    ``folder``, creating it if needed."""
    os.makedirs(folder, exist_ok=True)
    prefill_prof.profiler.export_chrome_trace(
        os.path.join(folder, "prefill.json"))
    for idx, decode_prof in enumerate(decode_profs, start=1):
        decode_prof.profiler.export_chrome_trace(
            os.path.join(folder, f"decode_{idx}.json"))
    print("Traces saved as prefill.json and decode_1.json, etc."
          f" in folder {folder}")


def run_profile(context: ProfileContext, csv_output: Optional[str],
                json_output: Optional[str]):
    """
    Profile the engine steps described by ``context`` and report the results.

    Builds an ``LLM`` from ``context.engine_args`` and checks that the batch
    fits the engine limits. It then does a short warm-up (one prefill and one
    decode step, after which the warm-up requests are aborted), submits a
    fresh batch, and profiles every engine step with ``layerwise_profile``.
    Prefill and first-decode-step tables are printed. Results are optionally
    exported as CSV, JSON and chrome traces.

    Args:
        context: Profiling configuration.
        csv_output: Root filename for the CSV exports, or None to skip them.
        json_output: Filename for the JSON export, or None to skip it.

    Exits the process via ``sys.exit(-1)`` if the batch exceeds
    ``max_num_batched_tokens``, ``max_num_seqs`` or the model's
    ``max_model_len``.
    """
    print("Run profile with:")
    for key, value in asdict(context).items():
        print(f"  {key} = {value}")

    requests_per_step: List[int] = determine_requests_per_step(context)

    ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
        context.batch_size, requests_per_step)

    num_steps_to_profile: int = len(requests_per_step)
    max_output_len: int = max(ol_nr.keys())
    assert max_output_len >= 1

    # Output length of each request, in submission order. The map's values
    # sum to batch_size (asserted in compute_request_output_lengths).
    output_lens: List[int] = [
        output_len for output_len, num_reqs in ol_nr.items()
        for _ in range(num_reqs)
    ]

    # Create LLM
    llm = LLM(**asdict(context.engine_args))
    batch_size = context.batch_size
    prompt_len = context.prompt_len

    _check_engine_limits(llm, batch_size, prompt_len, max_output_len)

    def add_requests():
        vocab_size = llm.llm_engine.model_config.get_vocab_size()
        for i, output_len in enumerate(output_lens):
            # A separate SamplingParams per request, so each request owns its
            # max_tokens rather than all requests sharing one mutated object.
            sampling_params = SamplingParams(
                temperature=0.8,
                top_p=0.95,
                max_tokens=output_len,
                ignore_eos=True)
            prompt_token_ids = torch.randint(vocab_size,
                                             size=(prompt_len, )).tolist()
            llm.llm_engine.add_request(
                request_id=f"seq{i}",
                prompt={'prompt_token_ids': prompt_token_ids},
                params=sampling_params)

    def abort_requests():
        for i in range(batch_size):
            llm.llm_engine.abort_request(f"seq{i}")

    # Warm up run
    print("Warm up run ...")
    add_requests()
    llm.llm_engine.step()  # Prefill
    llm.llm_engine.step()  # Decode
    abort_requests()

    print("Profile run ...")
    add_requests()

    with layerwise_profile() as prefill_prof:
        llm.llm_engine.step()  # First step is prefill

    decode_profs = []
    for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
        num_running_seqs = llm.llm_engine.scheduler[
            0].get_num_unfinished_seq_groups()
        with layerwise_profile(
                num_running_seqs=num_running_seqs) as decode_prof:
            llm.llm_engine.step()
        decode_profs.append(decode_prof)

    prefill_results = prefill_prof.results
    decode_results_list = [prof.results for prof in decode_profs]
    first_decode_results = (decode_results_list[0]
                            if decode_results_list else None)

    _print_result_tables(prefill_results, first_decode_results, prompt_len,
                         batch_size)

    if csv_output:
        _export_csv(csv_output, prefill_results, first_decode_results)

    if json_output:
        _export_json(json_output, context, prefill_results,
                     decode_results_list)

    if context.save_chrome_traces_folder is not None:
        _export_chrome_traces(context.save_chrome_traces_folder, prefill_prof,
                              decode_profs)

if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="""
Profile a model

    example:
    ```
    python examples/offline_inference/offline_profile.py \\
        --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
        --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
        --enforce-eager run_num_steps -n 2
    ```

    then you can use various tools to analyze the json output
    terminal ascii tables:
        ```
        python tools/profiler/print_layerwise_table.py \\
            --json-trace Llama31-8b-FP8.json --phase prefill --table summary
        ```
    or create matplotlib stacked bar charts:
        ```
        python tools/profiler/visualize_layerwise_profile.py \\
            --json-trace Llama31-8b-FP8.json \\
            --output-directory profile_breakdown --plot-metric pct_cuda_time
        ```
""",
                                    formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--csv",
        type=str,
        default=None,
        help="Export the results as multiple csv files. This should be the "
        "root filename, will create <filename>_prefill_model_table.csv, "
        "<filename>_prefill_summary_table.csv, "
        "<filename>_decode_model_table.csv, and "
        "<filename>_decode_summary_table.csv")
    parser.add_argument(
        "--json",
        type=str,
        default=None,
        help="Export the results as a json file. This should be the filename")
    parser.add_argument("--save-chrome-traces-folder",
                        type=str,
                        help="Save chrome traces for the prefill and decode "
                        "steps; they are saved as prefill.json, decode_1.json, "
                        "etc. inside this folder")
    parser.add_argument(
        "--prompt-len",
        type=int,
        default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling, all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
    parser.add_argument("--batch-size",
                        type=int,
                        default=BATCH_SIZE_DEFAULT,
                        help=f"Number of requests to run as a single batch, "
                        f"default={BATCH_SIZE_DEFAULT}")

    # One of the two profiling modes must be chosen; without it neither
    # num_steps nor complete_num_requests_per_step would be set.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    run_num_steps_parser = subparsers.add_parser(
        "run_num_steps",
        help="This variation profiles n engine.step() invocations.")
    run_num_steps_parser.add_argument(
        '-n',
        '--num-steps',
        type=int,
        required=True,
        help="Number of engine steps to profile.\n"
        "Setting it to 1, profiles only the prefill step.\n"
        "Setting it to 2, profiles the prefill and first decode step\n"
        "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
        "and so on ...")

    run_to_completion_parser = subparsers.add_parser(
        "run_to_completion",
        help="This variation profiles all the engine.step() invocations "
        "until the engine exhausts all submitted requests.")
    run_to_completion_parser.add_argument(
        '-n',
        '--complete-num-requests-per-step',
        type=int,
        required=True,
        help=
        "Complete complete_num_requests_per_step requests every decode step. "
        "For example, with batch_size 128 and complete_num_requests_per_step "
        "32, the profiler is run for 6 engine steps, with the steps "
        "processing 128, 128, 96, 64, 32, 1 requests respectively.\n"
        "Note that we tack-on a one-request step at the end as it is often "
        "useful.")

    EngineArgs.add_cli_args(parser)

    args = parser.parse_args()
    # Forward every parsed option whose name matches a ProfileContext field.
    context_fields = inspect.signature(ProfileContext).parameters
    context = ProfileContext(
        engine_args=EngineArgs.from_cli_args(args),
        **{k: v
           for k, v in vars(args).items() if k in context_fields})
    run_profile(context, csv_output=args.csv, json_output=args.json)