# graph_machete_bench.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
10
import regex as re
11
12
13
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

14
from vllm.utils.argparse_utils import FlexibleArgumentParser
15
16
17

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
18
19
20
21
        description="Benchmark the latency of processing a single batch of "
        "requests till completion."
    )
    parser.add_argument("filename", type=str)
22
23
24

    args = parser.parse_args()

25
    with open(args.filename, "rb") as f:
26
        data = pickle.load(f)
27
        raw_results: list[TMeasurement] = data["results"]
28
29

    results = defaultdict(lambda: list())
30
    for v in raw_results:
31
32
33
34
35
36
37
38
39
40
41
42
        result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
        if result is not None:
            KN = result.group(1)
        else:
            raise Exception("MKN not found")
        result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
        if result is not None:
            M = result.group(1)
        else:
            raise Exception("MKN not found")

        kernel = v.task_spec.description
43
        results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
44
45
46
47

    rows = int(math.ceil(len(results) / 2))
    fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
    axs = axs.flatten()
48
    for axs_idx, (shape, data) in enumerate(results.items()):
49
50
        plt.sca(axs[axs_idx])
        df = pd.DataFrame(data)
51
52
53
54
55
56
57
58
59
60
        sns.lineplot(
            data=df,
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
        )
61
62
63
64
        plt.title(f"Shape: {shape}")
        plt.ylabel("time (median, s)")
    plt.tight_layout()
    plt.savefig("graph_machete_bench.pdf")