"tests/vscode:/vscode.git/clone" did not exist on "c123bc33f9c139b293e906d9612e247c8e29dbbf"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import json
import math
import os
from pathlib import Path
10
from typing import Any, Optional
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

import matplotlib.pyplot as plt
import pandas as pd

## JSON parsing utils ####


def largest_dist_from_leaf(node: dict, depth: int = 0):
    """Return `depth` plus the length of the longest path from `node` down
    to a leaf.

    With the default `depth=0` this is the height of the subtree rooted at
    `node`; a node whose "children" list is empty has height 0.
    """
    children = node["children"]
    if not children:
        return depth
    # Each level of recursion contributes one step to the running depth.
    return max(
        largest_dist_from_leaf(child, depth=depth + 1) for child in children)


def get_entries_at_depth(depth: int,
                         entries_and_traces: list[tuple[Any, Any]],
                         node: dict,
                         curr_depth: int = 0,
                         trace=()):
    """Collect the entries that sit at a given height above the leaves.

    Walks the tree rooted at `node` and appends `(entry, trace)` tuples to
    `entries_and_traces` (mutated in place; nothing is returned).

    Args:
        depth: Level to collect, counted from the leaves. Only -1 (nodes of
            height 0, i.e. leaves) and -2 (nodes of height 1) are accepted.
        entries_and_traces: Output list that receives the matches.
        node: Tree node; a dict with an "entry" dict (which has a "name")
            and a "children" list of nodes of the same shape.
        curr_depth: Distance of `node` from the root of the walk; callers
            normally leave this at 0.
        trace: Names of the ancestors of `node`, nearest ancestor first.
    """
    # assert that the query is at kernel or module level
    assert depth == -1 or depth == -2

    # Height (distance to the furthest leaf) a node must have to be
    # collected, and the height of the current node. Computed once; the
    # helper is pure so a single call is sufficient.
    target_height = abs(depth) - 1
    node_height = largest_dist_from_leaf(node)

    if curr_depth == 0 and node_height <= target_height:
        # The tree is not tall enough! Report the root itself.
        entries_and_traces.append((node["entry"], trace))
        return

    if node_height == target_height:
        entries_and_traces.append((node["entry"], trace))

    # Children see this node as their nearest ancestor.
    trace = (node["entry"]["name"], ) + trace
    for child in node["children"]:
        get_entries_at_depth(depth,
                             entries_and_traces,
                             child,
                             curr_depth=curr_depth + 1,
                             trace=trace)


52
def fold_nodes(root: dict, nodes_to_fold: list[str]):
    """Turn every node whose entry name is in `nodes_to_fold` into a leaf.

    The tree is modified in place: a matching node gets its "children"
    replaced by an empty list, and its former descendants are not visited.

    Args:
        root: Tree node; a dict with an "entry" dict (which has a "name")
            and a "children" list of nodes of the same shape.
        nodes_to_fold: Entry names to fold.

    Returns:
        The same `root` object that was passed in.
    """
    # Iterative depth-first walk; avoids recursion on deep trees.
    stack: list[dict] = [root]
    while stack:
        node = stack.pop()
        if node['entry']['name'] in nodes_to_fold:
            node["children"] = []
            continue
        stack.extend(node["children"])
    return root


## Operation name cleanup utils ####


def trim_string_back(string: str, width: int) -> str:
    """Shorten `string` from the back so that it fits within `width`.

    Strings no longer than `width` are returned unchanged. Otherwise the
    tail is cut so that three characters are left free for an ellipsis; the
    "..." is only appended when more than three characters survive the cut.
    """
    overflow = len(string) - width
    if overflow <= 0:
        return string

    # Drop the overflow plus room for the three-character ellipsis.
    kept = string[:-(overflow + 3)]
    return kept + "..." if len(kept) > 3 else kept


def shorten_plot_legend_strings(legend, max_char_len: int):
    """Abbreviate and truncate every text label of `legend` in place.

    Each label is first passed through `abbreviate_known_names` and then
    cut to at most `max_char_len` characters with `trim_string_back`.
    """
    for label in legend.get_texts():
        abbreviated = abbreviate_known_names(label.get_text())
        label.set_text(trim_string_back(abbreviated, max_char_len))


def abbreviate_known_names(name: str) -> str:
    """Replace known long substrings in `name` with short forms.

    Replacements are applied one after another in the order listed below,
    so "bfloat16" is rewritten before the shorter "float16" is considered.
    """
    replacements = (
        ("MergedColumnParallelLinear", "MCPLinear"),
        ("QKVParallelLinear", "QKVPLinear"),
        ("RowParallelLinear", "RPLinear"),
        ("weight=", "w="),
        ("bfloat16", "bf16"),
        ("float16", "f16"),
    )
    shortened = name
    for long_form, short_form in replacements:
        shortened = shortened.replace(long_form, short_form)
    return shortened


def attempt_to_make_names_unique(entries_and_traces):
    """Disambiguate entries that share a name by appending ancestor names.

    `entries_and_traces` is a sequence of `(entry, trace)` pairs where
    `entry` is a dict with a "name" key and `trace` is a tuple of ancestor
    names. For every name that occurs more than once, the traces of the
    clashing entries are compared position by position; each entry is then
    renamed (in place) to its own name joined by " <- " with its trace up to
    and including the first position at which the traces differ.
    """
    seen_names: set = set()
    duplicated_names: set = set()
    for entry, _ in entries_and_traces:
        entry_name = entry["name"]
        if entry_name in seen_names:
            duplicated_names.add(entry_name)
        else:
            seen_names.add(entry_name)

    for duplicated in duplicated_names:
        clashing = []
        for entry, trace in entries_and_traces:
            if entry["name"] == duplicated:
                clashing.append((entry, trace))

        # Transpose the traces so that each item holds the ancestors found
        # at one position across all clashing entries (zip stops at the
        # shortest trace).
        ancestors_by_position = list(zip(*[trace for _, trace in clashing]))

        split_at = None
        for position, ancestors in enumerate(ancestors_by_position):
            if not all(ancestor == ancestors[0] for ancestor in ancestors):
                split_at = position
                break

        if split_at is None:
            # A unique name cannot be built; leave the names as they are.
            # They will get aggregated by the pivot_table call.
            continue

        for entry, trace in clashing:
            entry["name"] = " <- ".join((entry["name"], ) +
                                        trace[:split_at + 1])


## Operation grouping utils ####
'''
    Group operations in the given dataframe by some high-level ops like,
    - gemms
    - attention
    - rms_norm 
    etc.
'''


def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-operation columns into high-level operation groups.

    Every column name of `trace_df` is tested against an ordered list of
    substring-based predicates (attention, quantization, LoRA, gemm, ...).
    The first predicate that matches claims the column. For each group with
    at least one column, a new column holding the row-wise sum of its
    members is added, and the member columns are then dropped.

    Args:
        trace_df: Frame whose column names are operation names (strings).
            It is modified in place.

    Returns:
        The same `trace_df` object: unmatched columns keep their original
        position, followed by the group columns in the order listed in
        `op_groups` below.
    """

    def is_rms_norm(op_name: str) -> bool:
        return "rms_norm_kernel" in op_name

    def is_attention_block(op_name: str) -> bool:
        return "flash_fwd" in op_name or \
            "reshape_and_cache_flash_kernel" in op_name

    def is_quant(op_name: str) -> bool:
        return "scaled_fp8_quant" in op_name or \
            "scaled_int8_quant" in op_name

    # LoRA ops
    def is_sgmv_shrink(op_name: str) -> bool:
        return "sgmv_shrink" in op_name

    def is_sgmv_expand(op_name: str) -> bool:
        return "sgmv_expand" in op_name

    def is_bgmv_shrink(op_name: str) -> bool:
        return "bgmv_shrink" in op_name

    def is_bgmv_expand(op_name: str) -> bool:
        return "bgmv_expand" in op_name

    def is_cutlass_gemm_op(op_name: str) -> bool:
        return "void cutlass::Kernel" in op_name or \
            "void cutlass::device_kernel" in op_name

    def is_gemm_op(op_name: str) -> bool:
        # Quantization kernels are never counted as gemms.
        if is_quant(op_name):
            return False
        return is_cutlass_gemm_op(op_name) or \
            "xmma_gemm" in op_name or \
            "gemv2T_kernel" in op_name or \
            "splitKreduce" in op_name or \
            "s16816gemm" in op_name

    def is_elementwise_op(op_name: str) -> bool:
        return "elementwise_kernel" in op_name

    def is_mem_op(op_name: str) -> bool:
        lowered = op_name.lower()
        return "memcpy" in lowered or "memset" in lowered

    def is_vocab_embedding_op(op_name: str) -> bool:
        return "vocabparallelembed" in op_name.lower()

    # nccl ops
    def is_nccl_op(op_name: str) -> bool:
        return "nccl" in op_name.lower()

    def is_nccl_all_reduce(op_name: str) -> bool:
        lowered = op_name.lower()
        return is_nccl_op(op_name) and \
            ("all_reduce" in lowered or "allreduce" in lowered)

    def is_nccl_gather(op_name: str) -> bool:
        return is_nccl_op(op_name) and "gather" in op_name.lower()

    def is_nccl_broadcast(op_name: str) -> bool:
        return is_nccl_op(op_name) and "broadcast" in op_name.lower()

    # Reduce ops types
    def is_cross_device_reduce_1stage(op_name: str) -> bool:
        return "cross_device_reduce_1stage" in op_name

    def is_cross_device_reduce_2stage(op_name: str) -> bool:
        return "cross_device_reduce_2stage" in op_name

    def is_custom_ar_all_reduce(op_name: str) -> bool:
        return "_C_custom_ar::all_reduce" in op_name

    def is_reduce_kernel(op_name: str) -> bool:
        return "reduce_kernel" in op_name

    # (group column name, membership predicate). ORDER MATTERS: a column is
    # claimed by the first predicate that matches it (e.g. cutlass gemms
    # before generic gemms, specific nccl collectives before the nccl
    # catch-all), and group columns are appended in this order.
    op_groups = [
        ("attention", is_attention_block),
        ("quant_ops", is_quant),
        ("sgmv_shrink_ops", is_sgmv_shrink),
        ("sgmv_expand_ops", is_sgmv_expand),
        ("bgmv_shrink_ops", is_bgmv_shrink),
        ("bgmv_expand_ops", is_bgmv_expand),
        ("cutlass_gemm_ops", is_cutlass_gemm_op),
        ("gemm_ops", is_gemm_op),
        ("rms_norm_ops", is_rms_norm),
        ("vocab_embed_ops", is_vocab_embedding_op),
        ("mem_ops", is_mem_op),
        ("elementwise_ops", is_elementwise_op),
        ("nccl_all_reduce_ops", is_nccl_all_reduce),
        ("nccl_gather_ops", is_nccl_gather),
        ("nccl_broadcast_ops", is_nccl_broadcast),
        ("nccl_other_ops", is_nccl_op),
        ("cross_device_reduce_1stage_ops", is_cross_device_reduce_1stage),
        ("cross_device_reduce_2stage_ops", is_cross_device_reduce_2stage),
        ("custom_ar_all_reduce_ops", is_custom_ar_all_reduce),
        ("reduce_kernel_ops", is_reduce_kernel),
    ]

    # Assign each existing column to the first group that accepts it.
    columns_by_group: dict[str, list] = {name: [] for name, _ in op_groups}
    for op_name in list(trace_df):
        for group_name, belongs_to_group in op_groups:
            if belongs_to_group(op_name):
                columns_by_group[group_name].append(op_name)
                break

    # Add one summed column per non-empty group; the member columns are
    # dropped only after every group has been aggregated.
    grouped_columns: list = []
    for group_name, _ in op_groups:
        member_columns = columns_by_group[group_name]
        if member_columns:
            trace_df[group_name] = trace_df[member_columns].agg("sum",
                                                                axis=1)
            grouped_columns.extend(member_columns)

    trace_df.drop(grouped_columns, axis=1, inplace=True)
    return trace_df


## Data plotting utils ####


def plot_trace_df(traces_df: pd.DataFrame,
                  plot_metric: str,
                  plot_title: str,
                  output: Optional[Path] = None):
    """Draw one stacked bar per phase, split by operation group.

    Args:
        traces_df: Long-format frame with at least the columns "phase",
            "phase_desc", "name" and `plot_metric`.
        plot_metric: Column of `traces_df` whose values are summed and
            plotted.
        plot_title: Figure title.
        output: File to save the figure to. When None, nothing is written
            and the figure is left open for the caller.
    """

    def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str:
        phase_df = traces_df.query(f'phase == "{phase}"')
        descs = phase_df['phase_desc'].to_list()
        # Every row of a phase must carry the same description.
        assert all(desc == descs[0] for desc in descs)
        return descs[0]

    phases = traces_df['phase'].unique()
    phase_descs = [get_phase_description(traces_df, p) for p in phases]

    # One row per phase, one column per operation name.
    traces_df = traces_df.pivot_table(index="phase",
                                      columns="name",
                                      values=plot_metric,
                                      aggfunc="sum")

    traces_df = group_trace_by_operations(traces_df)

    # Make the figure
    fig_size_x = max(5, len(phases))
    fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True)

    # Draw the stacked bars
    ops = list(traces_df)
    bottom = [0] * len(phases)
    for op in ops:
        values = [traces_df[op][phase] for phase in phases]
        # An operation that is absent from a phase shows up as NaN in the
        # pivot table; plot it as zero.
        values = [0.0 if math.isnan(value) else value for value in values]
        ax.bar(phase_descs, values, label=op, bottom=bottom)
        bottom = [base + value for base, value in zip(bottom, values)]

    # Write the values as text on the bars
    for bar in ax.patches:
        if bar.get_height() != 0:
            ax.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() / 2 + bar.get_y(),
                    f"{round(bar.get_height(), 2)}",
                    ha='center',
                    color='w',
                    weight='bold',
                    size=5)

    # Setup legend
    handles, labels = ax.get_legend_handles_labels()
    legend = fig.legend(handles,
                        labels,
                        loc='center left',
                        bbox_to_anchor=(1, 1))
    shorten_plot_legend_strings(legend, 50)

    # Setup labels and title
    plt.setp(ax.get_xticklabels(), rotation=90)
    ax.set_ylabel(plot_metric)
    fig.suptitle(plot_title)

    if output is None:
        # No destination given: keep the figure alive so the caller can
        # show or save it.
        return

    fig.savefig(output, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure open and
    # they pile up across calls.
    plt.close(fig)
    print("Created: ", output)


def main(
        json_trace: Path,
        output_directory: Path,
        depth: int,  # Fetch/Plot operations at this depth of the Json tree
        plot_metric: str,
        make_names_unique: bool,
        top_k: int,
        json_nodes_to_fold: list[str],
        step_plot_interval: Optional[int] = None):
    """Load a layerwise profile JSON and plot the prefill and decode steps.

    Writes "prefill.png" and, when the profile has decode steps,
    "decode_steps.png" into `output_directory`.

    Args:
        json_trace: Path of the profile JSON. Its first key must be
            "context"; the remaining keys must each contain "prefill" or
            "decode", with exactly one prefill key.
        output_directory: Directory that receives the plots.
        depth: Tree level to plot, as accepted by `get_entries_at_depth`.
        plot_metric: Column to plot (e.g. "cuda_time_ms").
        make_names_unique: Whether to disambiguate duplicate entry names
            using their ancestors.
        top_k: If truthy, entries outside the top entries by
            "cuda_time_us" are merged under the name "others".
        json_nodes_to_fold: Entry names whose children are discarded before
            plotting.
        step_plot_interval: Plot one decode step out of every
            `step_plot_interval` (the last decode step is always plotted).
            When None, the value is read from the module-level `args`
            namespace if the script defined one, else 4 is used.
    """

    def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame:

        def get_entries_and_traces(key: str):
            entries_and_traces: list[tuple[Any, Any]] = []
            for root in profile_json[key]["summary_stats"]:
                # Fold nodes in the traces as per user request. i.e. simply
                # make the requested nodes leaf-nodes.
                root = fold_nodes(root, json_nodes_to_fold)
                get_entries_at_depth(depth, entries_and_traces, root)
            return entries_and_traces

        def keep_only_top_entries(df: pd.DataFrame,
                                  metric: str,
                                  top_k: int = 9) -> pd.DataFrame:
            # Rename the (len(df) - top_k + 1) smallest rows to "others" so
            # that at most top_k distinct names remain, "others" included.
            df.loc[df.nsmallest(len(df) - top_k + 1, metric).index,
                   ["name"]] = "others"
            return df

        def get_phase_description(key: str) -> str:
            num_running_seqs = profile_json[key]['metadata'][
                'num_running_seqs']
            if num_running_seqs is not None:
                return f"{key}-seqs-{num_running_seqs}"
            else:
                return key

        # Get data for each key
        traces = [get_entries_and_traces(key) for key in step_keys]

        # Attempt some cleanup
        if make_names_unique:
            for trace in traces:
                attempt_to_make_names_unique(trace)

        # To pandas dataframe
        trace_dfs = [
            pd.DataFrame([entry for entry, _ in trace]).fillna(0)
            for trace in traces
        ]

        # Respect top_k
        if top_k:
            trace_dfs = [
                keep_only_top_entries(trace_df, "cuda_time_us", top_k)
                for trace_df in trace_dfs
            ]

        # Fill in information about the step-keys
        for trace_df, step_key in zip(trace_dfs, step_keys):
            trace_df['phase'] = step_key
            trace_df['phase_desc'] = get_phase_description(step_key)

        # Combine all data frames so they can be put in a single plot
        traces_df = pd.concat(trace_dfs)

        # Add a derived metric `cuda_time_ms`
        traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000
        traces_df = traces_df.fillna(0)

        return traces_df

    def make_plot_title_suffix(profile_json: dict) -> str:
        context = profile_json["context"]
        sparsity = context.get('sparsity', None)
        run_type = \
            f'Run {context["num_steps"]} steps' if context['num_steps'] else \
                (f'Complete {context["complete_num_requests_per_step"]} per '
                 f'step; Run till completion')
        return (f"{context['engine_args']['model']}\n"
                f"Batch={context['batch_size']}, "
                f"PromptLen={context['prompt_len']}, "
                f"NumGpus={context['engine_args']['tensor_parallel_size']}"
                f"{', Sparsity ' + sparsity if sparsity else ''}\n"
                f"Run Type: {run_type}")

    if step_plot_interval is None:
        # Backward compatibility: this function used to read the interval
        # from the module-level `args` created by the command-line entry
        # point, which breaks when `main` is imported and called directly.
        # Keep honouring that namespace when it exists; otherwise fall back
        # to the command-line default of 4.
        cli_args = globals().get("args")
        step_plot_interval = getattr(cli_args, "step_plot_interval", 4)

    with open(json_trace) as f:
        profile_json = json.load(f)
    assert profile_json is not None

    # Get all `llm.generate.step()` profile
    step_traces = list(profile_json.keys())
    assert (step_traces[0] == 'context')
    step_traces = step_traces[1:]  # have only prefill and decodes
    prefills = [key for key in step_traces if "prefill" in key]
    all_decodes = [key for key in step_traces if "decode" in key]
    assert len(prefills) + len(all_decodes) == len(step_traces)
    assert len(prefills) == 1

    decodes = all_decodes[::step_plot_interval]
    if decodes and decodes[-1] != all_decodes[-1]:
        # Always have the last decode
        decodes.append(all_decodes[-1])

    output_directory = Path(output_directory)
    plot_title_suffix = make_plot_title_suffix(profile_json)

    prefill_traces = prepare_data(profile_json, prefills)
    plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix,
                  output_directory / "prefill.png")

    if not decodes:
        # A profile without decode steps has nothing more to plot.
        return

    decode_traces = prepare_data(profile_json, decodes)
    plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix,
                  output_directory / "decode_steps.png")


# Command-line entry point: parse the options, derive the tree depth from
# --level, make sure the output directory exists and produce the plots.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--json-trace",
                        type=str,
                        required=True,
                        help="json trace file output by \
                              examples/offline_inference/profiling.py")
    parser.add_argument("--output-directory",
                        type=str,
                        required=False,
                        help="Directory to output plots")
    parser.add_argument("--level",
                        type=str,
                        default="module",
                        choices=["module", "kernel"])
    parser.add_argument("--top-k",
                        type=int,
                        default=12,
                        help="Only graph the top `top_k` entries by time.")
    parser.add_argument("--fold-json-node",
                        nargs='+',
                        default=['Sampler', 'LogitsProcessor'],
                        help='Do not plot the children of these nodes. Let, \
                              the node represent the aggregate of all its \
                              children')
    parser.add_argument("--plot-metric",
                        type=str,
                        default="cuda_time_ms",
                        help='Metric to plot. some options are cuda_time_ms, \
                                pct_cuda_time')
    parser.add_argument(
        "--step-plot-interval",
        type=int,
        default=4,
        help="For every `step_plot_interval` steps, plot 1 step")

    # NOTE: `args` must stay a module-level name; `main` reads
    # `args.step_plot_interval` from it.
    args = parser.parse_args()

    # Prepare/Extract relevant args
    make_names_unique = False
    if args.level == "module":
        depth = -2
        # Module names repeat across layers, so disambiguate them.
        make_names_unique = True
    elif args.level == "kernel":
        depth = -1
    else:
        raise Exception(f"Unexpected level value ({args.level})")

    # Default to the directory that holds the trace file.
    output_directory = (Path(args.output_directory) if args.output_directory
                        else Path(args.json_trace).parent)

    # exist_ok avoids the check-then-create race of a separate exists test.
    os.makedirs(output_directory, exist_ok=True)

    main(Path(args.json_trace), output_directory, depth, args.plot_metric,
         make_names_unique, args.top_k, args.fold_json_node)