# graph_machete_bench.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
9
import regex as re
10
11
12
13
14
15
16
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.utils import FlexibleArgumentParser

# Matches labels such as "MKN=(16x4096x8192)": group 1 is M (the batch
# size), group 2 is the "KxN" part that identifies the weight shape.
MKN_PATTERN = re.compile(r"MKN=\((\d+)x(\d+x\d+)\)")


def parse_mkn(sub_label: str) -> tuple[str, str]:
    """Extract the ``KxN`` shape and the ``M`` batch size from a sub-label.

    Args:
        sub_label: Benchmark sub-label containing ``MKN=(<M>x<K>x<N>)``.

    Returns:
        A ``(shape, batch_size)`` tuple of strings, e.g.
        ``("4096x8192", "16")`` for ``"MKN=(16x4096x8192)"``.

    Raises:
        ValueError: If ``sub_label`` contains no ``MKN=(...)`` token.
    """
    match = MKN_PATTERN.search(sub_label)
    if match is None:
        raise ValueError(f"MKN not found in sub_label: {sub_label!r}")
    batch_size, shape = match.groups()
    return shape, batch_size


def group_results(raw_results: "list[TMeasurement]") -> dict:
    """Group benchmark measurements by their ``KxN`` shape.

    Args:
        raw_results: Measurements whose ``task_spec.sub_label`` carries an
            ``MKN=(...)`` token and whose ``task_spec.description`` names
            the kernel.

    Returns:
        A dict mapping each ``KxN`` string (in first-seen order) to a list of
        ``{"kernel", "batch_size", "median"}`` records, one per measurement.

    Raises:
        ValueError: If any measurement's sub-label has no ``MKN=(...)`` token.
    """
    grouped = defaultdict(list)
    for measurement in raw_results:
        shape, batch_size = parse_mkn(measurement.task_spec.sub_label)
        # NOTE(review): batch_size stays a string, exactly as captured from
        # the label; converting it to int would presumably switch the x-axis
        # from categorical to numeric -- confirm which one is intended.
        grouped[shape].append(
            {
                "kernel": measurement.task_spec.description,
                "batch_size": batch_size,
                "median": measurement.median,
            }
        )
    return dict(grouped)


def plot_results(results: dict, output_path: str = "graph_machete_bench.pdf") -> None:
    """Draw one median-vs-batch-size line plot per shape and save the figure.

    Args:
        results: Mapping of shape string to a list of record dicts, as
            produced by ``group_results``.
        output_path: File the figure is written to.

    Raises:
        ValueError: If ``results`` is empty (there is nothing to plot).
    """
    if not results:
        # plt.subplots(0, 2) would fail with an unrelated-looking error.
        raise ValueError("no benchmark results to plot")

    # Two plots per row; an odd count leaves one unused cell in the last row.
    rows = int(math.ceil(len(results) / 2))
    fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
    axs = axs.flatten()

    for ax, (shape, records) in zip(axs, results.items()):
        sns.lineplot(
            data=pd.DataFrame(records),
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
            ax=ax,
        )
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")

    # Hide the leftover cell instead of rendering an empty frame.
    for unused_ax in axs[len(results):]:
        unused_ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(output_path)


def main() -> None:
    """Load pickled benchmark results and write ``graph_machete_bench.pdf``."""
    parser = FlexibleArgumentParser(
        description="Plot the median kernel latency against batch size, one "
        "subplot per KxN shape, from a pickled benchmark results file."
    )
    parser.add_argument("filename", type=str)
    args = parser.parse_args()

    # SECURITY: pickle.load can execute arbitrary code; only pass files that
    # were produced by a trusted benchmark run.
    with open(args.filename, "rb") as f:
        payload = pickle.load(f)
    raw_results: list[TMeasurement] = payload["results"]

    plot_results(group_results(raw_results))


if __name__ == "__main__":
    main()