# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Plot median kernel latency from a pickled machete benchmark result file.

The input pickle must hold a ``"results"`` entry: a list of
``torch.utils.benchmark.Measurement`` objects whose ``task_spec.sub_label``
contains an ``MKN=(MxKxN)`` tag. One subplot is drawn per ``KxN`` shape. Each
subplot plots median time against batch size (``M``), with one line per kernel
(``task_spec.description``). The figure is written to
``graph_machete_bench.pdf`` in the current working directory.
"""

import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
import regex as re
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.utils import FlexibleArgumentParser

# Matches the "MKN=(MxKxN)" tag in a measurement's sub_label.
# Group 1 is M (the batch size); group 2 is the "KxN" shape string.
_MKN_RE = re.compile(r"MKN=\((\d+)x(\d+x\d+)\)")

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Plot median latency of machete benchmark results, "
        "one subplot per KxN shape."
    )
    parser.add_argument("filename", type=str)

    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only load result files
    # you generated yourself.
    with open(args.filename, "rb") as f:
        data = pickle.load(f)
        raw_results: list[TMeasurement] = data["results"]

    # Group measurements by "KxN" shape; each entry is one plotted point.
    results = defaultdict(list)
    for v in raw_results:
        match = _MKN_RE.search(v.task_spec.sub_label)
        if match is None:
            raise ValueError(
                f"MKN not found in sub_label: {v.task_spec.sub_label!r}"
            )
        # M is kept as the raw matched string, so matplotlib plots batch
        # sizes as categories in first-seen order rather than numerically.
        M, KN = match.group(1), match.group(2)

        kernel = v.task_spec.description
        results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})

    # plt.subplots cannot build a grid with zero rows; fail with a clear message.
    if not results:
        raise ValueError(f"no benchmark results found in {args.filename!r}")

    # Two subplots per row, enough rows for every shape.
    rows = int(math.ceil(len(results) / 2))
    fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
    axs = axs.flatten()
    for ax, (shape, entries) in zip(axs, results.items()):
        df = pd.DataFrame(entries)
        sns.lineplot(
            data=df,
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
            ax=ax,
        )
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")

    # With an odd number of shapes the last grid cell is unused; hide it
    # instead of leaving an empty framed subplot.
    for ax in axs[len(results) :]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig("graph_machete_bench.pdf")