# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
import regex as re
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.utils import FlexibleArgumentParser

# Matches the "MKN=(MxKxN)" tag in a measurement's sub_label.
# Group 1 is the batch size M; group 2 is the "KxN" shape used to pick the subplot.
_MKN_RE = re.compile(r"MKN=\((\d+)x(\d+x\d+)\)")


def _parse_mkn(sub_label: str) -> tuple[str, str]:
    """Return ``(M, "KxN")`` parsed from a measurement ``sub_label``.

    Both parts are returned as the strings matched in the label.

    Raises:
        ValueError: if ``sub_label`` contains no ``MKN=(MxKxN)`` tag.
    """
    match = _MKN_RE.search(sub_label)
    if match is None:
        raise ValueError(f"MKN not found in sub_label {sub_label!r}")
    return match.group(1), match.group(2)


def _group_by_shape(raw_results: list["TMeasurement"]) -> dict[str, list[dict]]:
    """Group measurements into plot rows keyed by their "KxN" shape string.

    Each row is ``{"kernel": task_spec.description, "batch_size": M,
    "median": measurement.median}``. ``M`` is kept as the string matched
    from the sub_label. Insertion order of shapes and rows follows
    ``raw_results``.
    """
    results: dict[str, list[dict]] = defaultdict(list)
    for measurement in raw_results:
        batch_size, shape = _parse_mkn(measurement.task_spec.sub_label)
        results[shape].append(
            {
                "kernel": measurement.task_spec.description,
                "batch_size": batch_size,
                "median": measurement.median,
            }
        )
    return results


def _plot_results(results: dict[str, list[dict]], output_path: str) -> None:
    """Draw one latency-vs-batch-size subplot per shape (two per row) and save it.

    Raises:
        ValueError: if ``results`` is empty (there is nothing to plot).
    """
    if not results:
        # plt.subplots(0, 2) would otherwise fail with a less helpful error.
        raise ValueError("no benchmark results to plot")

    n_rows = math.ceil(len(results) / 2)
    fig, axs = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
    axs = axs.flatten()

    for ax, (shape, shape_rows) in zip(axs, results.items()):
        sns.lineplot(
            data=pd.DataFrame(shape_rows),
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
            ax=ax,
        )
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")

    # With an odd number of shapes the last grid cell has no data;
    # hide it rather than leaving an empty set of axes in the figure.
    for ax in axs[len(results) :]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Graph machete benchmark results: median kernel latency "
        "versus batch size, one subplot per KxN shape."
    )
    parser.add_argument(
        "filename",
        type=str,
        help="pickle file whose 'results' entry is a list of "
        "torch.utils.benchmark Measurement objects",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="graph_machete_bench.pdf",
        help="path of the figure to write (default: %(default)s)",
    )
    args = parser.parse_args()

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only pass result files you produced yourself or otherwise trust.
    with open(args.filename, "rb") as f:
        bench_data = pickle.load(f)
    raw_results: list[TMeasurement] = bench_data["results"]

    _plot_results(_group_by_shape(raw_results), args.output)